mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
(doc/fluxion/chain) add/convert docstrings to mkdocstrings format
This commit is contained in:
parent
77b97b3c8e
commit
c7fd1496b5
|
@ -16,24 +16,14 @@ T = TypeVar("T", bound=Module)
|
||||||
TChain = TypeVar("TChain", bound="Chain") # because Self (PEP 673) is not in 3.10
|
TChain = TypeVar("TChain", bound="Chain") # because Self (PEP 673) is not in 3.10
|
||||||
|
|
||||||
|
|
||||||
class Lambda(Module):
|
|
||||||
"""Lambda is a wrapper around a callable object that allows it to be used as a PyTorch module."""
|
|
||||||
|
|
||||||
def __init__(self, func: Callable[..., Any]) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.func = func
|
|
||||||
|
|
||||||
def forward(self, *args: Any) -> Any:
|
|
||||||
return self.func(*args)
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
func_name = getattr(self.func, "__name__", "partial_function")
|
|
||||||
return f"Lambda({func_name}{str(inspect.signature(self.func))})"
|
|
||||||
|
|
||||||
|
|
||||||
def generate_unique_names(
|
def generate_unique_names(
|
||||||
modules: tuple[Module, ...],
|
modules: tuple[Module, ...],
|
||||||
) -> dict[str, Module]:
|
) -> dict[str, Module]:
|
||||||
|
"""Generate unique names for each Module in a sequence.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
modules: The sequence of Modules to name.
|
||||||
|
"""
|
||||||
class_counts: dict[str, int] = {}
|
class_counts: dict[str, int] = {}
|
||||||
unique_names: list[tuple[str, Module]] = []
|
unique_names: list[tuple[str, Module]] = []
|
||||||
for module in modules:
|
for module in modules:
|
||||||
|
@ -48,69 +38,8 @@ def generate_unique_names(
|
||||||
return dict(unique_names)
|
return dict(unique_names)
|
||||||
|
|
||||||
|
|
||||||
class UseContext(ContextModule):
|
|
||||||
def __init__(self, context: str, key: str) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.context = context
|
|
||||||
self.key = key
|
|
||||||
self.func: Callable[[Any], Any] = lambda x: x
|
|
||||||
|
|
||||||
def __call__(self, *args: Any) -> Any:
|
|
||||||
context = self.use_context(self.context)
|
|
||||||
assert context, f"context {self.context} is unset"
|
|
||||||
value = context.get(self.key)
|
|
||||||
assert value is not None, f"context entry {self.context}.{self.key} is unset"
|
|
||||||
return self.func(value)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"{self.__class__.__name__}(context={repr(self.context)}, key={repr(self.key)})"
|
|
||||||
|
|
||||||
def compose(self, func: Callable[[Any], Any]) -> "UseContext":
|
|
||||||
self.func = func
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class SetContext(ContextModule):
|
|
||||||
"""A Module that sets a context value when executed.
|
|
||||||
|
|
||||||
The context need to pre exist in the context provider.
|
|
||||||
#TODO Is there a way to create the context if it doesn't exist?
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, context: str, key: str, callback: Callable[[Any, Any], Any] | None = None) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.context = context
|
|
||||||
self.key = key
|
|
||||||
self.callback = callback
|
|
||||||
|
|
||||||
def __call__(self, x: Tensor) -> Tensor:
|
|
||||||
if context := self.use_context(self.context):
|
|
||||||
if not self.callback:
|
|
||||||
context.update({self.key: x})
|
|
||||||
else:
|
|
||||||
self.callback(context[self.key], x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"{self.__class__.__name__}(context={repr(self.context)}, key={repr(self.key)})"
|
|
||||||
|
|
||||||
|
|
||||||
class ReturnException(Exception):
|
|
||||||
"""Exception raised when a Return module is encountered."""
|
|
||||||
|
|
||||||
def __init__(self, value: Tensor):
|
|
||||||
self.value = value
|
|
||||||
|
|
||||||
|
|
||||||
class Return(Module):
|
|
||||||
"""A Module that stops the execution of a Chain when encountered."""
|
|
||||||
|
|
||||||
def forward(self, x: Tensor):
|
|
||||||
raise ReturnException(x)
|
|
||||||
|
|
||||||
|
|
||||||
def structural_copy(m: T) -> T:
|
def structural_copy(m: T) -> T:
|
||||||
|
"""Helper function to copy a Module's tree, only if it is a ContextModule instance."""
|
||||||
return m.structural_copy() if isinstance(m, ContextModule) else m
|
return m.structural_copy() if isinstance(m, ContextModule) else m
|
||||||
|
|
||||||
|
|
||||||
|
@ -122,6 +51,29 @@ class ChainError(RuntimeError):
|
||||||
|
|
||||||
|
|
||||||
class Chain(ContextModule):
|
class Chain(ContextModule):
|
||||||
|
"""Chain layer.
|
||||||
|
|
||||||
|
This layer is the main building block of Fluxion.
|
||||||
|
It is used to compose other layers in a sequential manner.
|
||||||
|
Similary to [`torch.nn.Sequential`][torch.nn.Sequential],
|
||||||
|
it calls each of its sub-layers in order, chaining their outputs as inputs to the next sublayer.
|
||||||
|
However, it also provides additional methods to manipulate its sub-layers and their context.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
chain = fl.Chain(
|
||||||
|
fl.Linear(32, 64),
|
||||||
|
fl.ReLU(),
|
||||||
|
fl.Linear(64, 128),
|
||||||
|
)
|
||||||
|
|
||||||
|
tensor = torch.randn(2, 32)
|
||||||
|
output = chain(tensor)
|
||||||
|
|
||||||
|
assert output.shape == (2, 128)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
_modules: dict[str, Module]
|
_modules: dict[str, Module]
|
||||||
_provider: ContextProvider
|
_provider: ContextProvider
|
||||||
_tag = "CHAIN"
|
_tag = "CHAIN"
|
||||||
|
@ -165,12 +117,23 @@ class Chain(ContextModule):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def provider(self) -> ContextProvider:
|
def provider(self) -> ContextProvider:
|
||||||
|
"""The [`ContextProvider`][refiners.fluxion.context.ContextProvider] of the Chain."""
|
||||||
return self._provider
|
return self._provider
|
||||||
|
|
||||||
def init_context(self) -> Contexts:
|
def init_context(self) -> Contexts:
|
||||||
|
"""Initialize the context provider with some default values.
|
||||||
|
|
||||||
|
This method is called when the Chain is created, and when it is reset.
|
||||||
|
This method may be overridden by subclasses to provide default values for the context provider.
|
||||||
|
"""
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def _register_provider(self, context: Contexts | None = None) -> None:
|
def _register_provider(self, context: Contexts | None = None) -> None: # TODO: rename me ?
|
||||||
|
"""Recursively update the context provider to all sub-modules.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: The context to use to update the provider.
|
||||||
|
"""
|
||||||
if context:
|
if context:
|
||||||
self._provider.update_contexts(context)
|
self._provider.update_contexts(context)
|
||||||
|
|
||||||
|
@ -179,9 +142,16 @@ class Chain(ContextModule):
|
||||||
module._register_provider(context=self._provider.contexts)
|
module._register_provider(context=self._provider.contexts)
|
||||||
|
|
||||||
def _reset_context(self) -> None:
|
def _reset_context(self) -> None:
|
||||||
|
"""Reset the context provider to its initial state."""
|
||||||
self._register_provider(self.init_context())
|
self._register_provider(self.init_context())
|
||||||
|
|
||||||
def set_context(self, context: str, value: Any) -> None:
|
def set_context(self, context: str, value: Any) -> None:
|
||||||
|
"""Set a value in the context provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: The context to update.
|
||||||
|
value: The value to set.
|
||||||
|
"""
|
||||||
self._provider.set_context(context, value)
|
self._provider.set_context(context, value)
|
||||||
self._register_provider()
|
self._register_provider()
|
||||||
|
|
||||||
|
@ -315,18 +285,27 @@ class Chain(ContextModule):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self) -> Device | None:
|
def device(self) -> Device | None:
|
||||||
|
"""The PyTorch device of the Chain's parameters."""
|
||||||
wm = self.find(WeightedModule)
|
wm = self.find(WeightedModule)
|
||||||
return None if wm is None else wm.device
|
return None if wm is None else wm.device
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self) -> DType | None:
|
def dtype(self) -> DType | None:
|
||||||
|
"""The PyTorch dtype of the Chain's parameters."""
|
||||||
wm = self.find(WeightedModule)
|
wm = self.find(WeightedModule)
|
||||||
return None if wm is None else wm.dtype
|
return None if wm is None else wm.dtype
|
||||||
|
|
||||||
def _walk(
|
def _walk(
|
||||||
self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
|
self,
|
||||||
|
predicate: Callable[[Module, "Chain"], bool] | None = None,
|
||||||
|
recurse: bool = False,
|
||||||
) -> Iterator[tuple[Module, "Chain"]]:
|
) -> Iterator[tuple[Module, "Chain"]]:
|
||||||
|
"""Walk the Chain's sub-module tree and yield each module that matches the predicate.
|
||||||
|
|
||||||
|
The predicate is a (Module, Chain) -> bool function.
|
||||||
|
"""
|
||||||
if predicate is None:
|
if predicate is None:
|
||||||
|
# if no predicate is given, yield all modules
|
||||||
predicate = lambda _m, _p: True
|
predicate = lambda _m, _p: True
|
||||||
for module in self:
|
for module in self:
|
||||||
try:
|
try:
|
||||||
|
@ -342,35 +321,100 @@ class Chain(ContextModule):
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def walk(
|
def walk(
|
||||||
self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
|
self,
|
||||||
|
predicate: Callable[[Module, "Chain"], bool] | None = None,
|
||||||
|
recurse: bool = False,
|
||||||
) -> Iterator[tuple[Module, "Chain"]]:
|
) -> Iterator[tuple[Module, "Chain"]]:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def walk(self, predicate: type[T], recurse: bool = False) -> Iterator[tuple[T, "Chain"]]:
|
def walk(
|
||||||
|
self,
|
||||||
|
predicate: type[T],
|
||||||
|
recurse: bool = False,
|
||||||
|
) -> Iterator[tuple[T, "Chain"]]:
|
||||||
...
|
...
|
||||||
|
|
||||||
def walk(
|
def walk(
|
||||||
self, predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
|
self,
|
||||||
|
predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None,
|
||||||
|
recurse: bool = False,
|
||||||
) -> Iterator[tuple[T, "Chain"]] | Iterator[tuple[Module, "Chain"]]:
|
) -> Iterator[tuple[T, "Chain"]] | Iterator[tuple[Module, "Chain"]]:
|
||||||
if isinstance(predicate, type):
|
"""Walk the Chain's sub-module tree and yield each module that matches the predicate.
|
||||||
return self._walk(lambda m, _: isinstance(m, predicate), recurse)
|
|
||||||
else:
|
|
||||||
return self._walk(predicate, recurse)
|
|
||||||
|
|
||||||
def layers(self, layer_type: type[T], recurse: bool = False) -> Iterator[T]:
|
Args:
|
||||||
|
predicate: The predicate to match.
|
||||||
|
recurse: Whether to recurse into sub-Chains.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Each module that matches the predicate.
|
||||||
|
"""
|
||||||
|
if isinstance(predicate, type):
|
||||||
|
# if the predicate is a Module type
|
||||||
|
# build a predicate function that matches the type
|
||||||
|
return self._walk(
|
||||||
|
predicate=lambda m, _: isinstance(m, predicate),
|
||||||
|
recurse=recurse,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self._walk(
|
||||||
|
predicate=predicate,
|
||||||
|
recurse=recurse,
|
||||||
|
)
|
||||||
|
|
||||||
|
def layers(
|
||||||
|
self,
|
||||||
|
layer_type: type[T],
|
||||||
|
recurse: bool = False,
|
||||||
|
) -> Iterator[T]:
|
||||||
|
"""Walk the Chain's sub-module tree and yield each layer of the given type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer_type: The type of layer to yield.
|
||||||
|
recurse: Whether to recurse into sub-Chains.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Each module of the given layer_type.
|
||||||
|
"""
|
||||||
for module, _ in self.walk(layer_type, recurse):
|
for module, _ in self.walk(layer_type, recurse):
|
||||||
yield module
|
yield module
|
||||||
|
|
||||||
def find(self, layer_type: type[T]) -> T | None:
|
def find(self, layer_type: type[T]) -> T | None:
|
||||||
|
"""Walk the Chain's sub-module tree and return the first layer of the given type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer_type: The type of layer to find.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The first module of the given layer_type, or None if it doesn't exist.
|
||||||
|
"""
|
||||||
return next(self.layers(layer_type=layer_type), None)
|
return next(self.layers(layer_type=layer_type), None)
|
||||||
|
|
||||||
def ensure_find(self, layer_type: type[T]) -> T:
|
def ensure_find(self, layer_type: type[T]) -> T:
|
||||||
|
"""Walk the Chain's sub-module tree and return the first layer of the given type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer_type: The type of layer to find.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The first module of the given layer_type.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If the module doesn't exist.
|
||||||
|
"""
|
||||||
r = self.find(layer_type)
|
r = self.find(layer_type)
|
||||||
assert r is not None, f"could not find {layer_type} in {self}"
|
assert r is not None, f"could not find {layer_type} in {self}"
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def find_parent(self, module: Module) -> "Chain | None":
|
def find_parent(self, module: Module) -> "Chain | None":
|
||||||
|
"""Walk the Chain's sub-module tree and return the parent of the given module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module: The module whose parent to find.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The parent of the given module, or None if it doesn't exist.
|
||||||
|
"""
|
||||||
if module in self: # avoid DFS-crawling the whole tree
|
if module in self: # avoid DFS-crawling the whole tree
|
||||||
return self
|
return self
|
||||||
for _, parent in self.walk(lambda m, _: m == module):
|
for _, parent in self.walk(lambda m, _: m == module):
|
||||||
|
@ -378,11 +422,31 @@ class Chain(ContextModule):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def ensure_find_parent(self, module: Module) -> "Chain":
|
def ensure_find_parent(self, module: Module) -> "Chain":
|
||||||
|
"""Walk the Chain's sub-module tree and return the parent of the given module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module: The module whose parent to find.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The parent of the given module.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If the module doesn't exist.
|
||||||
|
"""
|
||||||
r = self.find_parent(module)
|
r = self.find_parent(module)
|
||||||
assert r is not None, f"could not find {module} in {self}"
|
assert r is not None, f"could not find {module} in {self}"
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def insert(self, index: int, module: Module) -> None:
|
def insert(self, index: int, module: Module) -> None:
|
||||||
|
"""Insert a new module in the chain.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index: The index at which to insert the module.
|
||||||
|
module: The module to insert.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
IndexError: If the index is out of range.
|
||||||
|
"""
|
||||||
if index < 0:
|
if index < 0:
|
||||||
index = max(0, len(self._modules) + index + 1)
|
index = max(0, len(self._modules) + index + 1)
|
||||||
modules = list(self)
|
modules = list(self)
|
||||||
|
@ -393,6 +457,15 @@ class Chain(ContextModule):
|
||||||
self._register_provider()
|
self._register_provider()
|
||||||
|
|
||||||
def insert_before_type(self, module_type: type[Module], new_module: Module) -> None:
|
def insert_before_type(self, module_type: type[Module], new_module: Module) -> None:
|
||||||
|
"""Insert a new module in the chain, right before the first module of the given type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module_type: The type of module to insert before.
|
||||||
|
new_module: The module to insert.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no module of the given type exists in the chain.
|
||||||
|
"""
|
||||||
for i, module in enumerate(self):
|
for i, module in enumerate(self):
|
||||||
if isinstance(module, module_type):
|
if isinstance(module, module_type):
|
||||||
self.insert(i, new_module)
|
self.insert(i, new_module)
|
||||||
|
@ -400,6 +473,15 @@ class Chain(ContextModule):
|
||||||
raise ValueError(f"No module of type {module_type.__name__} found in the chain.")
|
raise ValueError(f"No module of type {module_type.__name__} found in the chain.")
|
||||||
|
|
||||||
def insert_after_type(self, module_type: type[Module], new_module: Module) -> None:
|
def insert_after_type(self, module_type: type[Module], new_module: Module) -> None:
|
||||||
|
"""Insert a new module in the chain, right after the first module of the given type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module_type: The type of module to insert after.
|
||||||
|
new_module: The module to insert.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no module of the given type exists in the chain.
|
||||||
|
"""
|
||||||
for i, module in enumerate(self):
|
for i, module in enumerate(self):
|
||||||
if isinstance(module, module_type):
|
if isinstance(module, module_type):
|
||||||
self.insert(i + 1, new_module)
|
self.insert(i + 1, new_module)
|
||||||
|
@ -407,9 +489,25 @@ class Chain(ContextModule):
|
||||||
raise ValueError(f"No module of type {module_type.__name__} found in the chain.")
|
raise ValueError(f"No module of type {module_type.__name__} found in the chain.")
|
||||||
|
|
||||||
def append(self, module: Module) -> None:
|
def append(self, module: Module) -> None:
|
||||||
|
"""Append a new module to the chain.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module: The module to append.
|
||||||
|
"""
|
||||||
self.insert(-1, module)
|
self.insert(-1, module)
|
||||||
|
|
||||||
def pop(self, index: int = -1) -> Module | tuple[Module]:
|
def pop(self, index: int = -1) -> Module | tuple[Module]:
|
||||||
|
"""Pop a module from the chain at the given index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index: The index of the module to pop.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The popped module.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
IndexError: If the index is out of range.
|
||||||
|
"""
|
||||||
modules = list(self)
|
modules = list(self)
|
||||||
if index < 0:
|
if index < 0:
|
||||||
index = len(modules) + index
|
index = len(modules) + index
|
||||||
|
@ -422,7 +520,14 @@ class Chain(ContextModule):
|
||||||
return removed_module
|
return removed_module
|
||||||
|
|
||||||
def remove(self, module: Module) -> None:
|
def remove(self, module: Module) -> None:
|
||||||
"""Remove a module from the chain."""
|
"""Remove a module from the chain.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module: The module to remove.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the module is not in the chain.
|
||||||
|
"""
|
||||||
modules = list(self)
|
modules = list(self)
|
||||||
try:
|
try:
|
||||||
modules.remove(module)
|
modules.remove(module)
|
||||||
|
@ -438,7 +543,17 @@ class Chain(ContextModule):
|
||||||
new_module: Module,
|
new_module: Module,
|
||||||
old_module_parent: "Chain | None" = None,
|
old_module_parent: "Chain | None" = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Replace a module in the chain with a new module."""
|
"""Replace a module in the chain with a new module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
old_module: The module to replace.
|
||||||
|
new_module: The module to replace with.
|
||||||
|
old_module_parent: The parent of the old module.
|
||||||
|
If None, the old module is orphanized.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the module is not in the chain.
|
||||||
|
"""
|
||||||
modules = list(self)
|
modules = list(self)
|
||||||
try:
|
try:
|
||||||
modules[modules.index(old_module)] = new_module
|
modules[modules.index(old_module)] = new_module
|
||||||
|
@ -479,29 +594,221 @@ class Chain(ContextModule):
|
||||||
return self.__class__ == Chain
|
return self.__class__ == Chain
|
||||||
|
|
||||||
|
|
||||||
|
class UseContext(ContextModule):
|
||||||
|
"""UseContext layer.
|
||||||
|
|
||||||
|
This layer reads from the [`ContextProvider`][refiners.fluxion.context.ContextProvider]
|
||||||
|
of its parent [`Chain`][refiners.fluxion.layers.chain.Chain].
|
||||||
|
|
||||||
|
Note: When called, it will
|
||||||
|
- Retrieve a value from the context using the given key
|
||||||
|
- Transform the value with the given function (optional)
|
||||||
|
- Return the value
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, context: str, key: str) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.context = context
|
||||||
|
self.key = key
|
||||||
|
self.func: Callable[[Any], Any] = lambda x: x
|
||||||
|
|
||||||
|
def __call__(self, *args: Any) -> Any:
|
||||||
|
context = self.use_context(self.context)
|
||||||
|
assert context, f"context {self.context} is unset"
|
||||||
|
value = context.get(self.key)
|
||||||
|
assert value is not None, f"context entry {self.context}.{self.key} is unset"
|
||||||
|
return self.func(value)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.__class__.__name__}(context={repr(self.context)}, key={repr(self.key)})"
|
||||||
|
|
||||||
|
def compose(self, func: Callable[[Any], Any]) -> "UseContext":
|
||||||
|
self.func = func
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class SetContext(ContextModule):
|
||||||
|
"""SetContext layer.
|
||||||
|
|
||||||
|
This layer writes to the [`ContextProvider`][refiners.fluxion.context.ContextProvider]
|
||||||
|
of its parent [`Chain`][refiners.fluxion.layers.chain.Chain].
|
||||||
|
|
||||||
|
Note: When called (without a callback), it will
|
||||||
|
- Update the context with the given key and the input value
|
||||||
|
- Return the input value
|
||||||
|
|
||||||
|
Note: When called (with a callback), it will
|
||||||
|
- Call the callback with the current value and the input value
|
||||||
|
(the callback may update the context with a new value, or not)
|
||||||
|
- Return the input value
|
||||||
|
|
||||||
|
Warning:
|
||||||
|
The context needs to already exist in the [`ContextProvider`][refiners.fluxion.context.ContextProvider]
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO: Create the context if it doesn't exist
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
context: str,
|
||||||
|
key: str,
|
||||||
|
callback: Callable[[Any, Any], Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.context = context
|
||||||
|
self.key = key
|
||||||
|
self.callback = callback
|
||||||
|
|
||||||
|
def __call__(self, x: Tensor) -> Tensor:
|
||||||
|
if context := self.use_context(self.context):
|
||||||
|
if not self.callback:
|
||||||
|
context.update({self.key: x})
|
||||||
|
else:
|
||||||
|
self.callback(context[self.key], x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.__class__.__name__}(context={repr(self.context)}, key={repr(self.key)})"
|
||||||
|
|
||||||
|
|
||||||
|
class Lambda(Module):
|
||||||
|
"""Lambda layer.
|
||||||
|
|
||||||
|
This layer wraps a [`Callable`][typing.Callable].
|
||||||
|
|
||||||
|
Note: When called, it will
|
||||||
|
- Execute the [`Callable`][typing.Callable] with the given arguments
|
||||||
|
- Return the output of the [`Callable`][typing.Callable])
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
lambda_layer = fl.Lambda(lambda x: x + 1)
|
||||||
|
|
||||||
|
tensor = torch.tensor([1, 2, 3])
|
||||||
|
output = lambda_layer(tensor)
|
||||||
|
|
||||||
|
expected_output = torch.tensor([2, 3, 4])
|
||||||
|
assert torch.allclose(output, expected_output)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, func: Callable[..., Any]) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.func = func
|
||||||
|
|
||||||
|
def forward(self, *args: Any) -> Any:
|
||||||
|
return self.func(*args)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
func_name = getattr(self.func, "__name__", "partial_function")
|
||||||
|
return f"Lambda({func_name}{str(inspect.signature(self.func))})"
|
||||||
|
|
||||||
|
|
||||||
class Parallel(Chain):
|
class Parallel(Chain):
|
||||||
|
"""Parallel layer.
|
||||||
|
|
||||||
|
This layer calls its sub-modules in parallel with the same inputs, and returns a tuple of their outputs.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
parallel = fl.Parallel(
|
||||||
|
fl.Linear(32, 64),
|
||||||
|
fl.Identity(),
|
||||||
|
fl.Linear(32, 128),
|
||||||
|
)
|
||||||
|
|
||||||
|
tensor = torch.randn(2, 32)
|
||||||
|
outputs = parallel(tensor)
|
||||||
|
|
||||||
|
assert len(outputs) == 3
|
||||||
|
assert outputs[0].shape == (2, 64)
|
||||||
|
assert torch.allclose(outputs[1], tensor)
|
||||||
|
assert outputs[2].shape == (2, 128)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
_tag = "PAR"
|
_tag = "PAR"
|
||||||
|
|
||||||
def forward(self, *args: Any) -> tuple[Tensor, ...]:
|
def forward(self, *args: Any) -> tuple[Tensor, ...]:
|
||||||
return tuple([self._call_layer(module, name, *args) for name, module in self._modules.items()])
|
return tuple(
|
||||||
|
[
|
||||||
|
self._call_layer(
|
||||||
|
module,
|
||||||
|
name,
|
||||||
|
*args, # same input for all sub-modules
|
||||||
|
)
|
||||||
|
for name, module in self._modules.items()
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
def _show_only_tag(self) -> bool:
|
def _show_only_tag(self) -> bool:
|
||||||
return self.__class__ == Parallel
|
return self.__class__ == Parallel
|
||||||
|
|
||||||
|
|
||||||
class Distribute(Chain):
|
class Distribute(Chain):
|
||||||
|
"""Distribute layer.
|
||||||
|
|
||||||
|
This layer calls its sub-modules in parallel with their respective input, and returns a tuple of their outputs.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
distribute = fl.Distribute(
|
||||||
|
fl.Linear(32, 128),
|
||||||
|
fl.Linear(64, 256),
|
||||||
|
)
|
||||||
|
|
||||||
|
tensor1 = torch.randn(2, 32)
|
||||||
|
tensor2 = torch.randn(4, 64)
|
||||||
|
outputs = distribute(tensor1, tensor2)
|
||||||
|
|
||||||
|
assert len(outputs) == 2
|
||||||
|
assert outputs[0].shape == (2, 128)
|
||||||
|
assert outputs[1].shape == (4, 256)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
_tag = "DISTR"
|
_tag = "DISTR"
|
||||||
|
|
||||||
def forward(self, *args: Any) -> tuple[Tensor, ...]:
|
def forward(self, *args: Any) -> tuple[Tensor, ...]:
|
||||||
n, m = len(args), len(self._modules)
|
n, m = len(args), len(self._modules)
|
||||||
assert n == m, f"Number of positional arguments ({n}) must match number of sub-modules ({m})."
|
assert n == m, f"Number of positional arguments ({n}) must match number of sub-modules ({m})."
|
||||||
return tuple([self._call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())])
|
return tuple(
|
||||||
|
[
|
||||||
|
self._call_layer(
|
||||||
|
module,
|
||||||
|
name,
|
||||||
|
arg, # each sub-module has its own input
|
||||||
|
)
|
||||||
|
for arg, (name, module) in zip(args, self._modules.items())
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def _show_only_tag(self) -> bool:
|
def _show_only_tag(self) -> bool:
|
||||||
return self.__class__ == Distribute
|
return self.__class__ == Distribute
|
||||||
|
|
||||||
|
|
||||||
class Passthrough(Chain):
|
class Passthrough(Chain):
|
||||||
|
"""Passthrough layer.
|
||||||
|
|
||||||
|
This layer call its sub-modules sequentially, and returns its original inputs,
|
||||||
|
like an [`Identity`][refiners.fluxion.layers.Identity] layer.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
passthrough = fl.Passthrough(
|
||||||
|
fl.Linear(32, 128),
|
||||||
|
fl.ReLU(),
|
||||||
|
fl.Linear(128, 128),
|
||||||
|
)
|
||||||
|
|
||||||
|
tensor = torch.randn(2, 32)
|
||||||
|
output = passthrough(tensor)
|
||||||
|
|
||||||
|
assert torch.allclose(output, tensor)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
_tag = "PASS"
|
_tag = "PASS"
|
||||||
|
|
||||||
def forward(self, *inputs: Any) -> Any:
|
def forward(self, *inputs: Any) -> Any:
|
||||||
|
@ -513,6 +820,24 @@ class Passthrough(Chain):
|
||||||
|
|
||||||
|
|
||||||
class Sum(Chain):
|
class Sum(Chain):
|
||||||
|
"""Summation layer.
|
||||||
|
|
||||||
|
This layer calls its sub-modules in parallel with the same inputs, and returns the sum of their outputs.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
summation = fl.Sum(
|
||||||
|
fl.Multiply(scale=2, bias=1),
|
||||||
|
fl.Multiply(scale=3, bias=0),
|
||||||
|
)
|
||||||
|
|
||||||
|
tensor = torch.ones(1)
|
||||||
|
output = summation(tensor)
|
||||||
|
|
||||||
|
assert torch.allclose(output, torch.tensor([6.0]))
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
_tag = "SUM"
|
_tag = "SUM"
|
||||||
|
|
||||||
def forward(self, *inputs: Any) -> Any:
|
def forward(self, *inputs: Any) -> Any:
|
||||||
|
@ -529,6 +854,24 @@ class Sum(Chain):
|
||||||
|
|
||||||
|
|
||||||
class Residual(Chain):
|
class Residual(Chain):
|
||||||
|
"""Residual layer.
|
||||||
|
|
||||||
|
This layer calls its sub-modules sequentially, and adds the original input to the output.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
residual = fl.Residual(
|
||||||
|
fl.Multiply(scale=10),
|
||||||
|
)
|
||||||
|
|
||||||
|
tensor = torch.ones(2, 32)
|
||||||
|
output = residual(tensor)
|
||||||
|
|
||||||
|
assert output.shape == (2, 32)
|
||||||
|
assert torch.allclose(output, 10 * tensor + tensor)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
_tag = "RES"
|
_tag = "RES"
|
||||||
|
|
||||||
def forward(self, *inputs: Any) -> Any:
|
def forward(self, *inputs: Any) -> Any:
|
||||||
|
@ -536,7 +879,101 @@ class Residual(Chain):
|
||||||
return super().forward(*inputs) + inputs[0]
|
return super().forward(*inputs) + inputs[0]
|
||||||
|
|
||||||
|
|
||||||
|
class Concatenate(Chain):
|
||||||
|
"""Concatenation layer.
|
||||||
|
|
||||||
|
This layer calls its sub-modules in parallel with the same inputs, and returns the concatenation of their outputs.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
concatenate = fl.Concatenate(
|
||||||
|
fl.Linear(32, 128),
|
||||||
|
fl.Linear(32, 128),
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
tensor = torch.randn(2, 32)
|
||||||
|
output = concatenate(tensor)
|
||||||
|
|
||||||
|
assert output.shape == (2, 256)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
_tag = "CAT"
|
||||||
|
|
||||||
|
def __init__(self, *modules: Module, dim: int = 0) -> None:
|
||||||
|
super().__init__(*modules)
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
def forward(self, *args: Any) -> Tensor:
|
||||||
|
outputs = [module(*args) for module in self]
|
||||||
|
return cat(
|
||||||
|
[output for output in outputs if output is not None],
|
||||||
|
dim=self.dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _show_only_tag(self) -> bool:
|
||||||
|
return self.__class__ == Concatenate
|
||||||
|
|
||||||
|
|
||||||
|
class Matmul(Chain):
|
||||||
|
"""Matrix multiplication layer.
|
||||||
|
|
||||||
|
This layer returns the matrix multiplication of the outputs of its two sub-modules.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
matmul = fl.Matmul(
|
||||||
|
fl.Identity(),
|
||||||
|
fl.Multiply(scale=2),
|
||||||
|
)
|
||||||
|
|
||||||
|
tensor = torch.randn(10, 10)
|
||||||
|
output = matmul(tensor)
|
||||||
|
|
||||||
|
expected_output = tensor @ (2 * tensor)
|
||||||
|
assert torch.allclose(output, expected_output)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
_tag = "MATMUL"
|
||||||
|
|
||||||
|
def __init__(self, input: Module, other: Module) -> None:
|
||||||
|
super().__init__(
|
||||||
|
input,
|
||||||
|
other,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, *args: Tensor) -> Tensor:
|
||||||
|
return torch.matmul(
|
||||||
|
input=self[0](*args),
|
||||||
|
other=self[1](*args),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ReturnException(Exception):
|
||||||
|
"""Exception raised when a Return module is encountered."""
|
||||||
|
|
||||||
|
def __init__(self, value: Tensor):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
|
||||||
|
class Return(Module):
|
||||||
|
"""Return layer.
|
||||||
|
|
||||||
|
This layer stops the execution of a Chain when encountered.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, x: Tensor):
|
||||||
|
raise ReturnException(x)
|
||||||
|
|
||||||
|
|
||||||
class Breakpoint(ContextModule):
|
class Breakpoint(ContextModule):
|
||||||
|
"""Breakpoint layer.
|
||||||
|
|
||||||
|
This layer pauses the execution when encountered, and opens a debugger.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, vscode: bool = True):
|
def __init__(self, vscode: bool = True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.vscode = vscode
|
self.vscode = vscode
|
||||||
|
@ -549,31 +986,3 @@ class Breakpoint(ContextModule):
|
||||||
else:
|
else:
|
||||||
breakpoint()
|
breakpoint()
|
||||||
return args[0] if len(args) == 1 else args
|
return args[0] if len(args) == 1 else args
|
||||||
|
|
||||||
|
|
||||||
class Concatenate(Chain):
|
|
||||||
_tag = "CAT"
|
|
||||||
|
|
||||||
def __init__(self, *modules: Module, dim: int = 0) -> None:
|
|
||||||
super().__init__(*modules)
|
|
||||||
self.dim = dim
|
|
||||||
|
|
||||||
def forward(self, *args: Any) -> Tensor:
|
|
||||||
outputs = [module(*args) for module in self]
|
|
||||||
return cat([output for output in outputs if output is not None], dim=self.dim)
|
|
||||||
|
|
||||||
def _show_only_tag(self) -> bool:
|
|
||||||
return self.__class__ == Concatenate
|
|
||||||
|
|
||||||
|
|
||||||
class Matmul(Chain):
|
|
||||||
_tag = "MATMUL"
|
|
||||||
|
|
||||||
def __init__(self, input: Module, other: Module) -> None:
|
|
||||||
super().__init__(
|
|
||||||
input,
|
|
||||||
other,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, *args: Tensor) -> Tensor:
|
|
||||||
return torch.matmul(input=self[0](*args), other=self[1](*args))
|
|
||||||
|
|
Loading…
Reference in a new issue