(doc/fluxion/chain) add/convert docstrings to mkdocstrings format

commit c7fd1496b5 (parent 77b97b3c8e)
Laurent, 2024-02-01 22:09:04 +00:00, committed by Laureηt


@@ -16,24 +16,14 @@ T = TypeVar("T", bound=Module)
TChain = TypeVar("TChain", bound="Chain") # because Self (PEP 673) is not in 3.10
def generate_unique_names(
modules: tuple[Module, ...],
) -> dict[str, Module]:
"""Generate unique names for each Module in a sequence.
Args:
modules: The sequence of Modules to name.
"""
class_counts: dict[str, int] = {}
unique_names: list[tuple[str, Module]] = []
for module in modules:
@@ -48,69 +38,8 @@ def generate_unique_names(
return dict(unique_names)
def structural_copy(m: T) -> T:
"""Helper function to copy a Module's tree, only if it is a ContextModule instance."""
return m.structural_copy() if isinstance(m, ContextModule) else m
@@ -122,6 +51,29 @@ class ChainError(RuntimeError):
class Chain(ContextModule):
"""Chain layer.
This layer is the main building block of Fluxion.
It is used to compose other layers in a sequential manner.
Similarly to [`torch.nn.Sequential`][torch.nn.Sequential],
it calls each of its sub-layers in order, chaining their outputs as inputs to the next sub-layer.
However, it also provides additional methods to manipulate its sub-layers and their context.
Example:
```py
chain = fl.Chain(
fl.Linear(32, 64),
fl.ReLU(),
fl.Linear(64, 128),
)
tensor = torch.randn(2, 32)
output = chain(tensor)
assert output.shape == (2, 128)
```
"""
_modules: dict[str, Module]
_provider: ContextProvider
_tag = "CHAIN"
@@ -165,12 +117,23 @@ class Chain(ContextModule):
@property
def provider(self) -> ContextProvider:
"""The [`ContextProvider`][refiners.fluxion.context.ContextProvider] of the Chain."""
return self._provider
def init_context(self) -> Contexts:
"""Initialize the context provider with some default values.
This method is called when the Chain is created, and when it is reset.
This method may be overridden by subclasses to provide default values for the context provider.
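Example:
A minimal sketch of a subclass declaring a default context (illustrative, not from the source; `Contexts` is the type used in this module's signatures):
```py
class MyChain(fl.Chain):
    def init_context(self) -> Contexts:
        return {"my_context": {"my_key": None}}
```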
"""
return {}
def _register_provider(self, context: Contexts | None = None) -> None: # TODO: rename me ?
"""Recursively update the context provider to all sub-modules.
Args:
context: The context to use to update the provider.
"""
if context:
self._provider.update_contexts(context)
@@ -179,9 +142,16 @@ class Chain(ContextModule):
module._register_provider(context=self._provider.contexts)
def _reset_context(self) -> None:
"""Reset the context provider to its initial state."""
self._register_provider(self.init_context())
def set_context(self, context: str, value: Any) -> None:
"""Set a value in the context provider.
Args:
context: The context to update.
value: The value to set.
"""
self._provider.set_context(context, value)
self._register_provider()
@@ -315,18 +285,27 @@ class Chain(ContextModule):
@property
def device(self) -> Device | None:
"""The PyTorch device of the Chain's parameters."""
wm = self.find(WeightedModule)
return None if wm is None else wm.device
@property
def dtype(self) -> DType | None:
"""The PyTorch dtype of the Chain's parameters."""
wm = self.find(WeightedModule)
return None if wm is None else wm.dtype
def _walk(
self,
predicate: Callable[[Module, "Chain"], bool] | None = None,
recurse: bool = False,
) -> Iterator[tuple[Module, "Chain"]]:
"""Walk the Chain's sub-module tree and yield each module that matches the predicate.
The predicate is a (Module, Chain) -> bool function.
"""
if predicate is None:
# if no predicate is given, yield all modules
predicate = lambda _m, _p: True
for module in self:
try:
@@ -342,35 +321,100 @@ class Chain(ContextModule):
@overload
def walk(
self,
predicate: Callable[[Module, "Chain"], bool] | None = None,
recurse: bool = False,
) -> Iterator[tuple[Module, "Chain"]]:
...
@overload
def walk(
self,
predicate: type[T],
recurse: bool = False,
) -> Iterator[tuple[T, "Chain"]]:
...
def walk(
self,
predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None,
recurse: bool = False,
) -> Iterator[tuple[T, "Chain"]] | Iterator[tuple[Module, "Chain"]]:
"""Walk the Chain's sub-module tree and yield each module that matches the predicate.
Args:
predicate: The predicate to match.
recurse: Whether to recurse into sub-Chains.
Yields:
Each module that matches the predicate.
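Example:
A minimal sketch (assuming `fl` aliases this layers module, as in the other examples):
```py
chain = fl.Chain(
    fl.Linear(32, 64),
    fl.Chain(fl.ReLU(), fl.Linear(64, 128)),
)
# a type as predicate, recursing into the nested Chain
linears = [m for m, _ in chain.walk(fl.Linear, recurse=True)]
assert len(linears) == 2
```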
"""
if isinstance(predicate, type):
# if the predicate is a Module type
# build a predicate function that matches the type
return self._walk(
predicate=lambda m, _: isinstance(m, predicate),
recurse=recurse,
)
else:
return self._walk(
predicate=predicate,
recurse=recurse,
)
def layers(
self,
layer_type: type[T],
recurse: bool = False,
) -> Iterator[T]:
"""Walk the Chain's sub-module tree and yield each layer of the given type.
Args:
layer_type: The type of layer to yield.
recurse: Whether to recurse into sub-Chains.
Yields:
Each module of the given layer_type.
"""
for module, _ in self.walk(layer_type, recurse):
yield module
def find(self, layer_type: type[T]) -> T | None:
"""Walk the Chain's sub-module tree and return the first layer of the given type.
Args:
layer_type: The type of layer to find.
Returns:
The first module of the given layer_type, or None if it doesn't exist.
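Example:
A minimal sketch (`fl.Conv2d` is assumed to be available, like the other fluxion layers):
```py
chain = fl.Chain(fl.Linear(32, 64), fl.ReLU())
assert chain.find(fl.Linear) is chain[0]  # first match in declaration order
assert chain.find(fl.Conv2d) is None  # no such layer in the chain
```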
"""
return next(self.layers(layer_type=layer_type), None)
def ensure_find(self, layer_type: type[T]) -> T:
"""Walk the Chain's sub-module tree and return the first layer of the given type.
Args:
layer_type: The type of layer to find.
Returns:
The first module of the given layer_type.
Raises:
AssertionError: If the module doesn't exist.
"""
r = self.find(layer_type)
assert r is not None, f"could not find {layer_type} in {self}"
return r
def find_parent(self, module: Module) -> "Chain | None":
"""Walk the Chain's sub-module tree and return the parent of the given module.
Args:
module: The module whose parent to find.
Returns:
The parent of the given module, or None if it doesn't exist.
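Example:
A minimal sketch:
```py
inner = fl.Linear(64, 64)
chain = fl.Chain(fl.Chain(inner))
assert chain.find_parent(inner) is chain[0]  # the nested Chain
assert chain.find_parent(fl.Linear(1, 1)) is None  # that module is not in the tree
```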
"""
if module in self: # avoid DFS-crawling the whole tree
return self
for _, parent in self.walk(lambda m, _: m == module):
@@ -378,11 +422,31 @@ class Chain(ContextModule):
return None
def ensure_find_parent(self, module: Module) -> "Chain":
"""Walk the Chain's sub-module tree and return the parent of the given module.
Args:
module: The module whose parent to find.
Returns:
The parent of the given module.
Raises:
AssertionError: If the module doesn't exist.
"""
r = self.find_parent(module)
assert r is not None, f"could not find {module} in {self}"
return r
def insert(self, index: int, module: Module) -> None:
"""Insert a new module in the chain.
Args:
index: The index at which to insert the module.
module: The module to insert.
Raises:
IndexError: If the index is out of range.
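Example:
A minimal sketch; as with lists, a negative index counts from the end:
```py
chain = fl.Chain(fl.Linear(32, 64), fl.Linear(64, 128))
chain.insert(1, fl.ReLU())  # now Linear -> ReLU -> Linear
assert isinstance(chain[1], fl.ReLU)
```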
"""
if index < 0:
index = max(0, len(self._modules) + index + 1)
modules = list(self)
@@ -393,6 +457,15 @@ class Chain(ContextModule):
self._register_provider()
def insert_before_type(self, module_type: type[Module], new_module: Module) -> None:
"""Insert a new module in the chain, right before the first module of the given type.
Args:
module_type: The type of module to insert before.
new_module: The module to insert.
Raises:
ValueError: If no module of the given type exists in the chain.
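Example:
A minimal sketch:
```py
chain = fl.Chain(fl.Linear(32, 64), fl.Linear(64, 128))
chain.insert_before_type(fl.Linear, fl.ReLU())  # before the *first* Linear
assert isinstance(chain[0], fl.ReLU)
```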
"""
for i, module in enumerate(self):
if isinstance(module, module_type):
self.insert(i, new_module)
@@ -400,6 +473,15 @@ class Chain(ContextModule):
raise ValueError(f"No module of type {module_type.__name__} found in the chain.")
def insert_after_type(self, module_type: type[Module], new_module: Module) -> None:
"""Insert a new module in the chain, right after the first module of the given type.
Args:
module_type: The type of module to insert after.
new_module: The module to insert.
Raises:
ValueError: If no module of the given type exists in the chain.
"""
for i, module in enumerate(self):
if isinstance(module, module_type):
self.insert(i + 1, new_module)
@@ -407,9 +489,25 @@ class Chain(ContextModule):
raise ValueError(f"No module of type {module_type.__name__} found in the chain.")
def append(self, module: Module) -> None:
"""Append a new module to the chain.
Args:
module: The module to append.
"""
self.insert(-1, module)
def pop(self, index: int = -1) -> Module | tuple[Module]:
"""Pop a module from the chain at the given index.
Args:
index: The index of the module to pop.
Returns:
The popped module.
Raises:
IndexError: If the index is out of range.
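Example:
A minimal sketch:
```py
chain = fl.Chain(fl.Linear(32, 64), fl.ReLU())
popped = chain.pop()  # index defaults to -1, the last module
assert isinstance(popped, fl.ReLU)
assert len(list(chain)) == 1
```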
"""
modules = list(self)
if index < 0:
index = len(modules) + index
@@ -422,7 +520,14 @@ class Chain(ContextModule):
return removed_module
def remove(self, module: Module) -> None:
"""Remove a module from the chain."""
"""Remove a module from the chain.
Args:
module: The module to remove.
Raises:
ValueError: If the module is not in the chain.
"""
modules = list(self)
try:
modules.remove(module)
@@ -438,7 +543,17 @@ class Chain(ContextModule):
new_module: Module,
old_module_parent: "Chain | None" = None,
) -> None:
"""Replace a module in the chain with a new module."""
"""Replace a module in the chain with a new module.
Args:
old_module: The module to replace.
new_module: The module to replace with.
old_module_parent: The parent of the old module.
If None, the old module is orphanized.
Raises:
ValueError: If the module is not in the chain.
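Example:
A minimal sketch (`fl.GELU` is assumed to be available, like the other fluxion layers):
```py
chain = fl.Chain(fl.Linear(32, 64), fl.ReLU())
chain.replace(old_module=chain.ensure_find(fl.ReLU), new_module=fl.GELU())
assert chain.find(fl.GELU) is not None
assert chain.find(fl.ReLU) is None
```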
"""
modules = list(self)
try:
modules[modules.index(old_module)] = new_module
@@ -479,29 +594,221 @@ class Chain(ContextModule):
return self.__class__ == Chain
class UseContext(ContextModule):
"""UseContext layer.
This layer reads from the [`ContextProvider`][refiners.fluxion.context.ContextProvider]
of its parent [`Chain`][refiners.fluxion.layers.chain.Chain].
Note: When called, it will
- Retrieve a value from the context using the given key
- Transform the value with the given function (optional)
- Return the value
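Example:
A minimal sketch; `set_context` fills the parent Chain's provider before the call:
```py
chain = fl.Chain(fl.UseContext("foo", "bar"))
chain.set_context("foo", {"bar": torch.ones(2)})
output = chain()
assert torch.allclose(output, torch.ones(2))
```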
"""
def __init__(self, context: str, key: str) -> None:
super().__init__()
self.context = context
self.key = key
self.func: Callable[[Any], Any] = lambda x: x
def __call__(self, *args: Any) -> Any:
context = self.use_context(self.context)
assert context, f"context {self.context} is unset"
value = context.get(self.key)
assert value is not None, f"context entry {self.context}.{self.key} is unset"
return self.func(value)
def __repr__(self):
return f"{self.__class__.__name__}(context={repr(self.context)}, key={repr(self.key)})"
def compose(self, func: Callable[[Any], Any]) -> "UseContext":
self.func = func
return self
class SetContext(ContextModule):
"""SetContext layer.
This layer writes to the [`ContextProvider`][refiners.fluxion.context.ContextProvider]
of its parent [`Chain`][refiners.fluxion.layers.chain.Chain].
Note: When called (without a callback), it will
- Update the context with the given key and the input value
- Return the input value
Note: When called (with a callback), it will
- Call the callback with the current value and the input value
(the callback may update the context with a new value, or not)
- Return the input value
Warning:
The context needs to already exist in the [`ContextProvider`][refiners.fluxion.context.ContextProvider]
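Example:
A minimal sketch; the "foo" context is created up front, since this layer does not create it:
```py
chain = fl.Chain(fl.SetContext("foo", "bar"))
chain.set_context("foo", {"bar": None})
tensor = torch.randn(2, 32)
assert torch.allclose(chain(tensor), tensor)  # SetContext passes its input through
assert torch.allclose(chain.use_context("foo")["bar"], tensor)  # and stores it
```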
"""
# TODO: Create the context if it doesn't exist
def __init__(
self,
context: str,
key: str,
callback: Callable[[Any, Any], Any] | None = None,
) -> None:
super().__init__()
self.context = context
self.key = key
self.callback = callback
def __call__(self, x: Tensor) -> Tensor:
if context := self.use_context(self.context):
if not self.callback:
context.update({self.key: x})
else:
self.callback(context[self.key], x)
return x
def __repr__(self):
return f"{self.__class__.__name__}(context={repr(self.context)}, key={repr(self.key)})"
class Lambda(Module):
"""Lambda layer.
This layer wraps a [`Callable`][typing.Callable].
Note: When called, it will
- Execute the [`Callable`][typing.Callable] with the given arguments
- Return the output of the [`Callable`][typing.Callable]
Example:
```py
lambda_layer = fl.Lambda(lambda x: x + 1)
tensor = torch.tensor([1, 2, 3])
output = lambda_layer(tensor)
expected_output = torch.tensor([2, 3, 4])
assert torch.allclose(output, expected_output)
```
"""
def __init__(self, func: Callable[..., Any]) -> None:
super().__init__()
self.func = func
def forward(self, *args: Any) -> Any:
return self.func(*args)
def __str__(self) -> str:
func_name = getattr(self.func, "__name__", "partial_function")
return f"Lambda({func_name}{str(inspect.signature(self.func))})"
class Parallel(Chain):
"""Parallel layer.
This layer calls its sub-modules in parallel with the same inputs, and returns a tuple of their outputs.
Example:
```py
parallel = fl.Parallel(
fl.Linear(32, 64),
fl.Identity(),
fl.Linear(32, 128),
)
tensor = torch.randn(2, 32)
outputs = parallel(tensor)
assert len(outputs) == 3
assert outputs[0].shape == (2, 64)
assert torch.allclose(outputs[1], tensor)
assert outputs[2].shape == (2, 128)
```
"""
_tag = "PAR"
def forward(self, *args: Any) -> tuple[Tensor, ...]:
return tuple(
[
self._call_layer(
module,
name,
*args, # same input for all sub-modules
)
for name, module in self._modules.items()
],
)
def _show_only_tag(self) -> bool:
return self.__class__ == Parallel
class Distribute(Chain):
"""Distribute layer.
This layer calls its sub-modules in parallel with their respective input, and returns a tuple of their outputs.
Example:
```py
distribute = fl.Distribute(
fl.Linear(32, 128),
fl.Linear(64, 256),
)
tensor1 = torch.randn(2, 32)
tensor2 = torch.randn(4, 64)
outputs = distribute(tensor1, tensor2)
assert len(outputs) == 2
assert outputs[0].shape == (2, 128)
assert outputs[1].shape == (4, 256)
```
"""
_tag = "DISTR"
def forward(self, *args: Any) -> tuple[Tensor, ...]:
n, m = len(args), len(self._modules)
assert n == m, f"Number of positional arguments ({n}) must match number of sub-modules ({m})."
return tuple(
[
self._call_layer(
module,
name,
arg, # each sub-module has its own input
)
for arg, (name, module) in zip(args, self._modules.items())
]
)
def _show_only_tag(self) -> bool:
return self.__class__ == Distribute
class Passthrough(Chain):
"""Passthrough layer.
This layer calls its sub-modules sequentially, and returns its original inputs,
like an [`Identity`][refiners.fluxion.layers.Identity] layer.
Example:
```py
passthrough = fl.Passthrough(
fl.Linear(32, 128),
fl.ReLU(),
fl.Linear(128, 128),
)
tensor = torch.randn(2, 32)
output = passthrough(tensor)
assert torch.allclose(output, tensor)
```
"""
_tag = "PASS"
def forward(self, *inputs: Any) -> Any:
@@ -513,6 +820,24 @@ class Passthrough(Chain):
class Sum(Chain):
"""Summation layer.
This layer calls its sub-modules in parallel with the same inputs, and returns the sum of their outputs.
Example:
```py
summation = fl.Sum(
fl.Multiply(scale=2, bias=1),
fl.Multiply(scale=3, bias=0),
)
tensor = torch.ones(1)
output = summation(tensor)
assert torch.allclose(output, torch.tensor([6.0]))
```
"""
_tag = "SUM"
def forward(self, *inputs: Any) -> Any:
@@ -529,6 +854,24 @@ class Sum(Chain):
class Residual(Chain):
"""Residual layer.
This layer calls its sub-modules sequentially, and adds the original input to the output.
Example:
```py
residual = fl.Residual(
fl.Multiply(scale=10),
)
tensor = torch.ones(2, 32)
output = residual(tensor)
assert output.shape == (2, 32)
assert torch.allclose(output, 10 * tensor + tensor)
```
"""
_tag = "RES"
def forward(self, *inputs: Any) -> Any:
@@ -536,7 +879,101 @@ class Residual(Chain):
return super().forward(*inputs) + inputs[0]
class Concatenate(Chain):
"""Concatenation layer.
This layer calls its sub-modules in parallel with the same inputs, and returns the concatenation of their outputs.
Example:
```py
concatenate = fl.Concatenate(
fl.Linear(32, 128),
fl.Linear(32, 128),
dim=1,
)
tensor = torch.randn(2, 32)
output = concatenate(tensor)
assert output.shape == (2, 256)
```
"""
_tag = "CAT"
def __init__(self, *modules: Module, dim: int = 0) -> None:
super().__init__(*modules)
self.dim = dim
def forward(self, *args: Any) -> Tensor:
outputs = [module(*args) for module in self]
return cat(
[output for output in outputs if output is not None],
dim=self.dim,
)
def _show_only_tag(self) -> bool:
return self.__class__ == Concatenate
class Matmul(Chain):
"""Matrix multiplication layer.
This layer returns the matrix multiplication of the outputs of its two sub-modules.
Example:
```py
matmul = fl.Matmul(
fl.Identity(),
fl.Multiply(scale=2),
)
tensor = torch.randn(10, 10)
output = matmul(tensor)
expected_output = tensor @ (2 * tensor)
assert torch.allclose(output, expected_output)
```
"""
_tag = "MATMUL"
def __init__(self, input: Module, other: Module) -> None:
super().__init__(
input,
other,
)
def forward(self, *args: Tensor) -> Tensor:
return torch.matmul(
input=self[0](*args),
other=self[1](*args),
)
class ReturnException(Exception):
"""Exception raised when a Return module is encountered."""
def __init__(self, value: Tensor):
self.value = value
class Return(Module):
"""Return layer.
This layer stops the execution of a Chain when encountered.
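Example:
A minimal sketch; the parent Chain is assumed to catch the ReturnException and return its value:
```py
chain = fl.Chain(
    fl.Multiply(scale=2),
    fl.Return(),
    fl.Multiply(scale=100),  # never executed
)
tensor = torch.ones(1)
assert torch.allclose(chain(tensor), 2 * tensor)
```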
"""
def forward(self, x: Tensor):
raise ReturnException(x)
class Breakpoint(ContextModule):
"""Breakpoint layer.
This layer pauses the execution when encountered, and opens a debugger.
"""
def __init__(self, vscode: bool = True):
super().__init__()
self.vscode = vscode
@@ -549,31 +986,3 @@ class Breakpoint(ContextModule):
else:
breakpoint()
return args[0] if len(args) == 1 else args