mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
(doc/fluxion/chain) add/convert docstrings to mkdocstrings format
This commit is contained in:
parent
77b97b3c8e
commit
c7fd1496b5
|
@ -16,24 +16,14 @@ T = TypeVar("T", bound=Module)
|
|||
TChain = TypeVar("TChain", bound="Chain") # because Self (PEP 673) is not in 3.10
|
||||
|
||||
|
||||
class Lambda(Module):
|
||||
"""Lambda is a wrapper around a callable object that allows it to be used as a PyTorch module."""
|
||||
|
||||
def __init__(self, func: Callable[..., Any]) -> None:
|
||||
super().__init__()
|
||||
self.func = func
|
||||
|
||||
def forward(self, *args: Any) -> Any:
|
||||
return self.func(*args)
|
||||
|
||||
def __str__(self) -> str:
|
||||
func_name = getattr(self.func, "__name__", "partial_function")
|
||||
return f"Lambda({func_name}{str(inspect.signature(self.func))})"
|
||||
|
||||
|
||||
def generate_unique_names(
|
||||
modules: tuple[Module, ...],
|
||||
) -> dict[str, Module]:
|
||||
"""Generate unique names for each Module in a sequence.
|
||||
|
||||
Args:
|
||||
modules: The sequence of Modules to name.
|
||||
"""
|
||||
class_counts: dict[str, int] = {}
|
||||
unique_names: list[tuple[str, Module]] = []
|
||||
for module in modules:
|
||||
|
@ -48,69 +38,8 @@ def generate_unique_names(
|
|||
return dict(unique_names)
|
||||
|
||||
|
||||
class UseContext(ContextModule):
|
||||
def __init__(self, context: str, key: str) -> None:
|
||||
super().__init__()
|
||||
self.context = context
|
||||
self.key = key
|
||||
self.func: Callable[[Any], Any] = lambda x: x
|
||||
|
||||
def __call__(self, *args: Any) -> Any:
|
||||
context = self.use_context(self.context)
|
||||
assert context, f"context {self.context} is unset"
|
||||
value = context.get(self.key)
|
||||
assert value is not None, f"context entry {self.context}.{self.key} is unset"
|
||||
return self.func(value)
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(context={repr(self.context)}, key={repr(self.key)})"
|
||||
|
||||
def compose(self, func: Callable[[Any], Any]) -> "UseContext":
|
||||
self.func = func
|
||||
return self
|
||||
|
||||
|
||||
class SetContext(ContextModule):
|
||||
"""A Module that sets a context value when executed.
|
||||
|
||||
The context need to pre exist in the context provider.
|
||||
#TODO Is there a way to create the context if it doesn't exist?
|
||||
"""
|
||||
|
||||
def __init__(self, context: str, key: str, callback: Callable[[Any, Any], Any] | None = None) -> None:
|
||||
super().__init__()
|
||||
self.context = context
|
||||
self.key = key
|
||||
self.callback = callback
|
||||
|
||||
def __call__(self, x: Tensor) -> Tensor:
|
||||
if context := self.use_context(self.context):
|
||||
if not self.callback:
|
||||
context.update({self.key: x})
|
||||
else:
|
||||
self.callback(context[self.key], x)
|
||||
|
||||
return x
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(context={repr(self.context)}, key={repr(self.key)})"
|
||||
|
||||
|
||||
class ReturnException(Exception):
|
||||
"""Exception raised when a Return module is encountered."""
|
||||
|
||||
def __init__(self, value: Tensor):
|
||||
self.value = value
|
||||
|
||||
|
||||
class Return(Module):
|
||||
"""A Module that stops the execution of a Chain when encountered."""
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
raise ReturnException(x)
|
||||
|
||||
|
||||
def structural_copy(m: T) -> T:
|
||||
"""Helper function to copy a Module's tree, only if it is a ContextModule instance."""
|
||||
return m.structural_copy() if isinstance(m, ContextModule) else m
|
||||
|
||||
|
||||
|
@ -122,6 +51,29 @@ class ChainError(RuntimeError):
|
|||
|
||||
|
||||
class Chain(ContextModule):
|
||||
"""Chain layer.
|
||||
|
||||
This layer is the main building block of Fluxion.
|
||||
It is used to compose other layers in a sequential manner.
|
||||
Similary to [`torch.nn.Sequential`][torch.nn.Sequential],
|
||||
it calls each of its sub-layers in order, chaining their outputs as inputs to the next sublayer.
|
||||
However, it also provides additional methods to manipulate its sub-layers and their context.
|
||||
|
||||
Example:
|
||||
```py
|
||||
chain = fl.Chain(
|
||||
fl.Linear(32, 64),
|
||||
fl.ReLU(),
|
||||
fl.Linear(64, 128),
|
||||
)
|
||||
|
||||
tensor = torch.randn(2, 32)
|
||||
output = chain(tensor)
|
||||
|
||||
assert output.shape == (2, 128)
|
||||
```
|
||||
"""
|
||||
|
||||
_modules: dict[str, Module]
|
||||
_provider: ContextProvider
|
||||
_tag = "CHAIN"
|
||||
|
@ -165,12 +117,23 @@ class Chain(ContextModule):
|
|||
|
||||
@property
|
||||
def provider(self) -> ContextProvider:
|
||||
"""The [`ContextProvider`][refiners.fluxion.context.ContextProvider] of the Chain."""
|
||||
return self._provider
|
||||
|
||||
def init_context(self) -> Contexts:
|
||||
"""Initialize the context provider with some default values.
|
||||
|
||||
This method is called when the Chain is created, and when it is reset.
|
||||
This method may be overridden by subclasses to provide default values for the context provider.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def _register_provider(self, context: Contexts | None = None) -> None:
|
||||
def _register_provider(self, context: Contexts | None = None) -> None: # TODO: rename me ?
|
||||
"""Recursively update the context provider to all sub-modules.
|
||||
|
||||
Args:
|
||||
context: The context to use to update the provider.
|
||||
"""
|
||||
if context:
|
||||
self._provider.update_contexts(context)
|
||||
|
||||
|
@ -179,9 +142,16 @@ class Chain(ContextModule):
|
|||
module._register_provider(context=self._provider.contexts)
|
||||
|
||||
def _reset_context(self) -> None:
|
||||
"""Reset the context provider to its initial state."""
|
||||
self._register_provider(self.init_context())
|
||||
|
||||
def set_context(self, context: str, value: Any) -> None:
|
||||
"""Set a value in the context provider.
|
||||
|
||||
Args:
|
||||
context: The context to update.
|
||||
value: The value to set.
|
||||
"""
|
||||
self._provider.set_context(context, value)
|
||||
self._register_provider()
|
||||
|
||||
|
@ -315,18 +285,27 @@ class Chain(ContextModule):
|
|||
|
||||
@property
|
||||
def device(self) -> Device | None:
|
||||
"""The PyTorch device of the Chain's parameters."""
|
||||
wm = self.find(WeightedModule)
|
||||
return None if wm is None else wm.device
|
||||
|
||||
@property
|
||||
def dtype(self) -> DType | None:
|
||||
"""The PyTorch dtype of the Chain's parameters."""
|
||||
wm = self.find(WeightedModule)
|
||||
return None if wm is None else wm.dtype
|
||||
|
||||
def _walk(
|
||||
self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
|
||||
self,
|
||||
predicate: Callable[[Module, "Chain"], bool] | None = None,
|
||||
recurse: bool = False,
|
||||
) -> Iterator[tuple[Module, "Chain"]]:
|
||||
"""Walk the Chain's sub-module tree and yield each module that matches the predicate.
|
||||
|
||||
The predicate is a (Module, Chain) -> bool function.
|
||||
"""
|
||||
if predicate is None:
|
||||
# if no predicate is given, yield all modules
|
||||
predicate = lambda _m, _p: True
|
||||
for module in self:
|
||||
try:
|
||||
|
@ -342,35 +321,100 @@ class Chain(ContextModule):
|
|||
|
||||
@overload
|
||||
def walk(
|
||||
self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
|
||||
self,
|
||||
predicate: Callable[[Module, "Chain"], bool] | None = None,
|
||||
recurse: bool = False,
|
||||
) -> Iterator[tuple[Module, "Chain"]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def walk(self, predicate: type[T], recurse: bool = False) -> Iterator[tuple[T, "Chain"]]:
|
||||
def walk(
|
||||
self,
|
||||
predicate: type[T],
|
||||
recurse: bool = False,
|
||||
) -> Iterator[tuple[T, "Chain"]]:
|
||||
...
|
||||
|
||||
def walk(
|
||||
self, predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
|
||||
self,
|
||||
predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None,
|
||||
recurse: bool = False,
|
||||
) -> Iterator[tuple[T, "Chain"]] | Iterator[tuple[Module, "Chain"]]:
|
||||
if isinstance(predicate, type):
|
||||
return self._walk(lambda m, _: isinstance(m, predicate), recurse)
|
||||
else:
|
||||
return self._walk(predicate, recurse)
|
||||
"""Walk the Chain's sub-module tree and yield each module that matches the predicate.
|
||||
|
||||
def layers(self, layer_type: type[T], recurse: bool = False) -> Iterator[T]:
|
||||
Args:
|
||||
predicate: The predicate to match.
|
||||
recurse: Whether to recurse into sub-Chains.
|
||||
|
||||
Yields:
|
||||
Each module that matches the predicate.
|
||||
"""
|
||||
if isinstance(predicate, type):
|
||||
# if the predicate is a Module type
|
||||
# build a predicate function that matches the type
|
||||
return self._walk(
|
||||
predicate=lambda m, _: isinstance(m, predicate),
|
||||
recurse=recurse,
|
||||
)
|
||||
else:
|
||||
return self._walk(
|
||||
predicate=predicate,
|
||||
recurse=recurse,
|
||||
)
|
||||
|
||||
def layers(
|
||||
self,
|
||||
layer_type: type[T],
|
||||
recurse: bool = False,
|
||||
) -> Iterator[T]:
|
||||
"""Walk the Chain's sub-module tree and yield each layer of the given type.
|
||||
|
||||
Args:
|
||||
layer_type: The type of layer to yield.
|
||||
recurse: Whether to recurse into sub-Chains.
|
||||
|
||||
Yields:
|
||||
Each module of the given layer_type.
|
||||
"""
|
||||
for module, _ in self.walk(layer_type, recurse):
|
||||
yield module
|
||||
|
||||
def find(self, layer_type: type[T]) -> T | None:
|
||||
"""Walk the Chain's sub-module tree and return the first layer of the given type.
|
||||
|
||||
Args:
|
||||
layer_type: The type of layer to find.
|
||||
|
||||
Returns:
|
||||
The first module of the given layer_type, or None if it doesn't exist.
|
||||
"""
|
||||
return next(self.layers(layer_type=layer_type), None)
|
||||
|
||||
def ensure_find(self, layer_type: type[T]) -> T:
|
||||
"""Walk the Chain's sub-module tree and return the first layer of the given type.
|
||||
|
||||
Args:
|
||||
layer_type: The type of layer to find.
|
||||
|
||||
Returns:
|
||||
The first module of the given layer_type.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the module doesn't exist.
|
||||
"""
|
||||
r = self.find(layer_type)
|
||||
assert r is not None, f"could not find {layer_type} in {self}"
|
||||
return r
|
||||
|
||||
def find_parent(self, module: Module) -> "Chain | None":
|
||||
"""Walk the Chain's sub-module tree and return the parent of the given module.
|
||||
|
||||
Args:
|
||||
module: The module whose parent to find.
|
||||
|
||||
Returns:
|
||||
The parent of the given module, or None if it doesn't exist.
|
||||
"""
|
||||
if module in self: # avoid DFS-crawling the whole tree
|
||||
return self
|
||||
for _, parent in self.walk(lambda m, _: m == module):
|
||||
|
@ -378,11 +422,31 @@ class Chain(ContextModule):
|
|||
return None
|
||||
|
||||
def ensure_find_parent(self, module: Module) -> "Chain":
|
||||
"""Walk the Chain's sub-module tree and return the parent of the given module.
|
||||
|
||||
Args:
|
||||
module: The module whose parent to find.
|
||||
|
||||
Returns:
|
||||
The parent of the given module.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the module doesn't exist.
|
||||
"""
|
||||
r = self.find_parent(module)
|
||||
assert r is not None, f"could not find {module} in {self}"
|
||||
return r
|
||||
|
||||
def insert(self, index: int, module: Module) -> None:
|
||||
"""Insert a new module in the chain.
|
||||
|
||||
Args:
|
||||
index: The index at which to insert the module.
|
||||
module: The module to insert.
|
||||
|
||||
Raises:
|
||||
IndexError: If the index is out of range.
|
||||
"""
|
||||
if index < 0:
|
||||
index = max(0, len(self._modules) + index + 1)
|
||||
modules = list(self)
|
||||
|
@ -393,6 +457,15 @@ class Chain(ContextModule):
|
|||
self._register_provider()
|
||||
|
||||
def insert_before_type(self, module_type: type[Module], new_module: Module) -> None:
|
||||
"""Insert a new module in the chain, right before the first module of the given type.
|
||||
|
||||
Args:
|
||||
module_type: The type of module to insert before.
|
||||
new_module: The module to insert.
|
||||
|
||||
Raises:
|
||||
ValueError: If no module of the given type exists in the chain.
|
||||
"""
|
||||
for i, module in enumerate(self):
|
||||
if isinstance(module, module_type):
|
||||
self.insert(i, new_module)
|
||||
|
@ -400,6 +473,15 @@ class Chain(ContextModule):
|
|||
raise ValueError(f"No module of type {module_type.__name__} found in the chain.")
|
||||
|
||||
def insert_after_type(self, module_type: type[Module], new_module: Module) -> None:
|
||||
"""Insert a new module in the chain, right after the first module of the given type.
|
||||
|
||||
Args:
|
||||
module_type: The type of module to insert after.
|
||||
new_module: The module to insert.
|
||||
|
||||
Raises:
|
||||
ValueError: If no module of the given type exists in the chain.
|
||||
"""
|
||||
for i, module in enumerate(self):
|
||||
if isinstance(module, module_type):
|
||||
self.insert(i + 1, new_module)
|
||||
|
@ -407,9 +489,25 @@ class Chain(ContextModule):
|
|||
raise ValueError(f"No module of type {module_type.__name__} found in the chain.")
|
||||
|
||||
def append(self, module: Module) -> None:
|
||||
"""Append a new module to the chain.
|
||||
|
||||
Args:
|
||||
module: The module to append.
|
||||
"""
|
||||
self.insert(-1, module)
|
||||
|
||||
def pop(self, index: int = -1) -> Module | tuple[Module]:
|
||||
"""Pop a module from the chain at the given index.
|
||||
|
||||
Args:
|
||||
index: The index of the module to pop.
|
||||
|
||||
Returns:
|
||||
The popped module.
|
||||
|
||||
Raises:
|
||||
IndexError: If the index is out of range.
|
||||
"""
|
||||
modules = list(self)
|
||||
if index < 0:
|
||||
index = len(modules) + index
|
||||
|
@ -422,7 +520,14 @@ class Chain(ContextModule):
|
|||
return removed_module
|
||||
|
||||
def remove(self, module: Module) -> None:
|
||||
"""Remove a module from the chain."""
|
||||
"""Remove a module from the chain.
|
||||
|
||||
Args:
|
||||
module: The module to remove.
|
||||
|
||||
Raises:
|
||||
ValueError: If the module is not in the chain.
|
||||
"""
|
||||
modules = list(self)
|
||||
try:
|
||||
modules.remove(module)
|
||||
|
@ -438,7 +543,17 @@ class Chain(ContextModule):
|
|||
new_module: Module,
|
||||
old_module_parent: "Chain | None" = None,
|
||||
) -> None:
|
||||
"""Replace a module in the chain with a new module."""
|
||||
"""Replace a module in the chain with a new module.
|
||||
|
||||
Args:
|
||||
old_module: The module to replace.
|
||||
new_module: The module to replace with.
|
||||
old_module_parent: The parent of the old module.
|
||||
If None, the old module is orphanized.
|
||||
|
||||
Raises:
|
||||
ValueError: If the module is not in the chain.
|
||||
"""
|
||||
modules = list(self)
|
||||
try:
|
||||
modules[modules.index(old_module)] = new_module
|
||||
|
@ -479,29 +594,221 @@ class Chain(ContextModule):
|
|||
return self.__class__ == Chain
|
||||
|
||||
|
||||
class UseContext(ContextModule):
|
||||
"""UseContext layer.
|
||||
|
||||
This layer reads from the [`ContextProvider`][refiners.fluxion.context.ContextProvider]
|
||||
of its parent [`Chain`][refiners.fluxion.layers.chain.Chain].
|
||||
|
||||
Note: When called, it will
|
||||
- Retrieve a value from the context using the given key
|
||||
- Transform the value with the given function (optional)
|
||||
- Return the value
|
||||
"""
|
||||
|
||||
def __init__(self, context: str, key: str) -> None:
|
||||
super().__init__()
|
||||
self.context = context
|
||||
self.key = key
|
||||
self.func: Callable[[Any], Any] = lambda x: x
|
||||
|
||||
def __call__(self, *args: Any) -> Any:
|
||||
context = self.use_context(self.context)
|
||||
assert context, f"context {self.context} is unset"
|
||||
value = context.get(self.key)
|
||||
assert value is not None, f"context entry {self.context}.{self.key} is unset"
|
||||
return self.func(value)
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(context={repr(self.context)}, key={repr(self.key)})"
|
||||
|
||||
def compose(self, func: Callable[[Any], Any]) -> "UseContext":
|
||||
self.func = func
|
||||
return self
|
||||
|
||||
|
||||
class SetContext(ContextModule):
|
||||
"""SetContext layer.
|
||||
|
||||
This layer writes to the [`ContextProvider`][refiners.fluxion.context.ContextProvider]
|
||||
of its parent [`Chain`][refiners.fluxion.layers.chain.Chain].
|
||||
|
||||
Note: When called (without a callback), it will
|
||||
- Update the context with the given key and the input value
|
||||
- Return the input value
|
||||
|
||||
Note: When called (with a callback), it will
|
||||
- Call the callback with the current value and the input value
|
||||
(the callback may update the context with a new value, or not)
|
||||
- Return the input value
|
||||
|
||||
Warning:
|
||||
The context needs to already exist in the [`ContextProvider`][refiners.fluxion.context.ContextProvider]
|
||||
"""
|
||||
|
||||
# TODO: Create the context if it doesn't exist
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: str,
|
||||
key: str,
|
||||
callback: Callable[[Any, Any], Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.context = context
|
||||
self.key = key
|
||||
self.callback = callback
|
||||
|
||||
def __call__(self, x: Tensor) -> Tensor:
|
||||
if context := self.use_context(self.context):
|
||||
if not self.callback:
|
||||
context.update({self.key: x})
|
||||
else:
|
||||
self.callback(context[self.key], x)
|
||||
|
||||
return x
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(context={repr(self.context)}, key={repr(self.key)})"
|
||||
|
||||
|
||||
class Lambda(Module):
|
||||
"""Lambda layer.
|
||||
|
||||
This layer wraps a [`Callable`][typing.Callable].
|
||||
|
||||
Note: When called, it will
|
||||
- Execute the [`Callable`][typing.Callable] with the given arguments
|
||||
- Return the output of the [`Callable`][typing.Callable])
|
||||
|
||||
Example:
|
||||
```py
|
||||
lambda_layer = fl.Lambda(lambda x: x + 1)
|
||||
|
||||
tensor = torch.tensor([1, 2, 3])
|
||||
output = lambda_layer(tensor)
|
||||
|
||||
expected_output = torch.tensor([2, 3, 4])
|
||||
assert torch.allclose(output, expected_output)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, func: Callable[..., Any]) -> None:
|
||||
super().__init__()
|
||||
self.func = func
|
||||
|
||||
def forward(self, *args: Any) -> Any:
|
||||
return self.func(*args)
|
||||
|
||||
def __str__(self) -> str:
|
||||
func_name = getattr(self.func, "__name__", "partial_function")
|
||||
return f"Lambda({func_name}{str(inspect.signature(self.func))})"
|
||||
|
||||
|
||||
class Parallel(Chain):
|
||||
"""Parallel layer.
|
||||
|
||||
This layer calls its sub-modules in parallel with the same inputs, and returns a tuple of their outputs.
|
||||
|
||||
Example:
|
||||
```py
|
||||
parallel = fl.Parallel(
|
||||
fl.Linear(32, 64),
|
||||
fl.Identity(),
|
||||
fl.Linear(32, 128),
|
||||
)
|
||||
|
||||
tensor = torch.randn(2, 32)
|
||||
outputs = parallel(tensor)
|
||||
|
||||
assert len(outputs) == 3
|
||||
assert outputs[0].shape == (2, 64)
|
||||
assert torch.allclose(outputs[1], tensor)
|
||||
assert outputs[2].shape == (2, 128)
|
||||
```
|
||||
"""
|
||||
|
||||
_tag = "PAR"
|
||||
|
||||
def forward(self, *args: Any) -> tuple[Tensor, ...]:
|
||||
return tuple([self._call_layer(module, name, *args) for name, module in self._modules.items()])
|
||||
return tuple(
|
||||
[
|
||||
self._call_layer(
|
||||
module,
|
||||
name,
|
||||
*args, # same input for all sub-modules
|
||||
)
|
||||
for name, module in self._modules.items()
|
||||
],
|
||||
)
|
||||
|
||||
def _show_only_tag(self) -> bool:
|
||||
return self.__class__ == Parallel
|
||||
|
||||
|
||||
class Distribute(Chain):
|
||||
"""Distribute layer.
|
||||
|
||||
This layer calls its sub-modules in parallel with their respective input, and returns a tuple of their outputs.
|
||||
|
||||
Example:
|
||||
```py
|
||||
distribute = fl.Distribute(
|
||||
fl.Linear(32, 128),
|
||||
fl.Linear(64, 256),
|
||||
)
|
||||
|
||||
tensor1 = torch.randn(2, 32)
|
||||
tensor2 = torch.randn(4, 64)
|
||||
outputs = distribute(tensor1, tensor2)
|
||||
|
||||
assert len(outputs) == 2
|
||||
assert outputs[0].shape == (2, 128)
|
||||
assert outputs[1].shape == (4, 256)
|
||||
```
|
||||
"""
|
||||
|
||||
_tag = "DISTR"
|
||||
|
||||
def forward(self, *args: Any) -> tuple[Tensor, ...]:
|
||||
n, m = len(args), len(self._modules)
|
||||
assert n == m, f"Number of positional arguments ({n}) must match number of sub-modules ({m})."
|
||||
return tuple([self._call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())])
|
||||
return tuple(
|
||||
[
|
||||
self._call_layer(
|
||||
module,
|
||||
name,
|
||||
arg, # each sub-module has its own input
|
||||
)
|
||||
for arg, (name, module) in zip(args, self._modules.items())
|
||||
]
|
||||
)
|
||||
|
||||
def _show_only_tag(self) -> bool:
|
||||
return self.__class__ == Distribute
|
||||
|
||||
|
||||
class Passthrough(Chain):
|
||||
"""Passthrough layer.
|
||||
|
||||
This layer call its sub-modules sequentially, and returns its original inputs,
|
||||
like an [`Identity`][refiners.fluxion.layers.Identity] layer.
|
||||
|
||||
Example:
|
||||
```py
|
||||
passthrough = fl.Passthrough(
|
||||
fl.Linear(32, 128),
|
||||
fl.ReLU(),
|
||||
fl.Linear(128, 128),
|
||||
)
|
||||
|
||||
tensor = torch.randn(2, 32)
|
||||
output = passthrough(tensor)
|
||||
|
||||
assert torch.allclose(output, tensor)
|
||||
```
|
||||
"""
|
||||
|
||||
_tag = "PASS"
|
||||
|
||||
def forward(self, *inputs: Any) -> Any:
|
||||
|
@ -513,6 +820,24 @@ class Passthrough(Chain):
|
|||
|
||||
|
||||
class Sum(Chain):
|
||||
"""Summation layer.
|
||||
|
||||
This layer calls its sub-modules in parallel with the same inputs, and returns the sum of their outputs.
|
||||
|
||||
Example:
|
||||
```py
|
||||
summation = fl.Sum(
|
||||
fl.Multiply(scale=2, bias=1),
|
||||
fl.Multiply(scale=3, bias=0),
|
||||
)
|
||||
|
||||
tensor = torch.ones(1)
|
||||
output = summation(tensor)
|
||||
|
||||
assert torch.allclose(output, torch.tensor([6.0]))
|
||||
```
|
||||
"""
|
||||
|
||||
_tag = "SUM"
|
||||
|
||||
def forward(self, *inputs: Any) -> Any:
|
||||
|
@ -529,6 +854,24 @@ class Sum(Chain):
|
|||
|
||||
|
||||
class Residual(Chain):
|
||||
"""Residual layer.
|
||||
|
||||
This layer calls its sub-modules sequentially, and adds the original input to the output.
|
||||
|
||||
Example:
|
||||
```py
|
||||
residual = fl.Residual(
|
||||
fl.Multiply(scale=10),
|
||||
)
|
||||
|
||||
tensor = torch.ones(2, 32)
|
||||
output = residual(tensor)
|
||||
|
||||
assert output.shape == (2, 32)
|
||||
assert torch.allclose(output, 10 * tensor + tensor)
|
||||
```
|
||||
"""
|
||||
|
||||
_tag = "RES"
|
||||
|
||||
def forward(self, *inputs: Any) -> Any:
|
||||
|
@ -536,7 +879,101 @@ class Residual(Chain):
|
|||
return super().forward(*inputs) + inputs[0]
|
||||
|
||||
|
||||
class Concatenate(Chain):
|
||||
"""Concatenation layer.
|
||||
|
||||
This layer calls its sub-modules in parallel with the same inputs, and returns the concatenation of their outputs.
|
||||
|
||||
Example:
|
||||
```py
|
||||
concatenate = fl.Concatenate(
|
||||
fl.Linear(32, 128),
|
||||
fl.Linear(32, 128),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
tensor = torch.randn(2, 32)
|
||||
output = concatenate(tensor)
|
||||
|
||||
assert output.shape == (2, 256)
|
||||
```
|
||||
"""
|
||||
|
||||
_tag = "CAT"
|
||||
|
||||
def __init__(self, *modules: Module, dim: int = 0) -> None:
|
||||
super().__init__(*modules)
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, *args: Any) -> Tensor:
|
||||
outputs = [module(*args) for module in self]
|
||||
return cat(
|
||||
[output for output in outputs if output is not None],
|
||||
dim=self.dim,
|
||||
)
|
||||
|
||||
def _show_only_tag(self) -> bool:
|
||||
return self.__class__ == Concatenate
|
||||
|
||||
|
||||
class Matmul(Chain):
|
||||
"""Matrix multiplication layer.
|
||||
|
||||
This layer returns the matrix multiplication of the outputs of its two sub-modules.
|
||||
|
||||
Example:
|
||||
```py
|
||||
matmul = fl.Matmul(
|
||||
fl.Identity(),
|
||||
fl.Multiply(scale=2),
|
||||
)
|
||||
|
||||
tensor = torch.randn(10, 10)
|
||||
output = matmul(tensor)
|
||||
|
||||
expected_output = tensor @ (2 * tensor)
|
||||
assert torch.allclose(output, expected_output)
|
||||
```
|
||||
"""
|
||||
|
||||
_tag = "MATMUL"
|
||||
|
||||
def __init__(self, input: Module, other: Module) -> None:
|
||||
super().__init__(
|
||||
input,
|
||||
other,
|
||||
)
|
||||
|
||||
def forward(self, *args: Tensor) -> Tensor:
|
||||
return torch.matmul(
|
||||
input=self[0](*args),
|
||||
other=self[1](*args),
|
||||
)
|
||||
|
||||
|
||||
class ReturnException(Exception):
|
||||
"""Exception raised when a Return module is encountered."""
|
||||
|
||||
def __init__(self, value: Tensor):
|
||||
self.value = value
|
||||
|
||||
|
||||
class Return(Module):
|
||||
"""Return layer.
|
||||
|
||||
This layer stops the execution of a Chain when encountered.
|
||||
"""
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
raise ReturnException(x)
|
||||
|
||||
|
||||
class Breakpoint(ContextModule):
|
||||
"""Breakpoint layer.
|
||||
|
||||
This layer pauses the execution when encountered, and opens a debugger.
|
||||
"""
|
||||
|
||||
def __init__(self, vscode: bool = True):
|
||||
super().__init__()
|
||||
self.vscode = vscode
|
||||
|
@ -549,31 +986,3 @@ class Breakpoint(ContextModule):
|
|||
else:
|
||||
breakpoint()
|
||||
return args[0] if len(args) == 1 else args
|
||||
|
||||
|
||||
class Concatenate(Chain):
|
||||
_tag = "CAT"
|
||||
|
||||
def __init__(self, *modules: Module, dim: int = 0) -> None:
|
||||
super().__init__(*modules)
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, *args: Any) -> Tensor:
|
||||
outputs = [module(*args) for module in self]
|
||||
return cat([output for output in outputs if output is not None], dim=self.dim)
|
||||
|
||||
def _show_only_tag(self) -> bool:
|
||||
return self.__class__ == Concatenate
|
||||
|
||||
|
||||
class Matmul(Chain):
|
||||
_tag = "MATMUL"
|
||||
|
||||
def __init__(self, input: Module, other: Module) -> None:
|
||||
super().__init__(
|
||||
input,
|
||||
other,
|
||||
)
|
||||
|
||||
def forward(self, *args: Tensor) -> Tensor:
|
||||
return torch.matmul(input=self[0](*args), other=self[1](*args))
|
||||
|
|
Loading…
Reference in a new issue