diff --git a/src/refiners/fluxion/layers/chain.py b/src/refiners/fluxion/layers/chain.py
index 6273c94..e866ae3 100644
--- a/src/refiners/fluxion/layers/chain.py
+++ b/src/refiners/fluxion/layers/chain.py
@@ -16,24 +16,14 @@ T = TypeVar("T", bound=Module)
 TChain = TypeVar("TChain", bound="Chain")  # because Self (PEP 673) is not in 3.10
 
-class Lambda(Module):
-    """Lambda is a wrapper around a callable object that allows it to be used as a PyTorch module."""
-
-    def __init__(self, func: Callable[..., Any]) -> None:
-        super().__init__()
-        self.func = func
-
-    def forward(self, *args: Any) -> Any:
-        return self.func(*args)
-
-    def __str__(self) -> str:
-        func_name = getattr(self.func, "__name__", "partial_function")
-        return f"Lambda({func_name}{str(inspect.signature(self.func))})"
-
-
 def generate_unique_names(
     modules: tuple[Module, ...],
 ) -> dict[str, Module]:
+    """Generate unique names for each Module in a sequence.
+
+    Args:
+        modules: The sequence of Modules to name.
+    """
     class_counts: dict[str, int] = {}
     unique_names: list[tuple[str, Module]] = []
     for module in modules:
@@ -48,69 +38,8 @@ def generate_unique_names(
     return dict(unique_names)
 
-class UseContext(ContextModule):
-    def __init__(self, context: str, key: str) -> None:
-        super().__init__()
-        self.context = context
-        self.key = key
-        self.func: Callable[[Any], Any] = lambda x: x
-
-    def __call__(self, *args: Any) -> Any:
-        context = self.use_context(self.context)
-        assert context, f"context {self.context} is unset"
-        value = context.get(self.key)
-        assert value is not None, f"context entry {self.context}.{self.key} is unset"
-        return self.func(value)
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}(context={repr(self.context)}, key={repr(self.key)})"
-
-    def compose(self, func: Callable[[Any], Any]) -> "UseContext":
-        self.func = func
-        return self
-
-
-class SetContext(ContextModule):
-    """A Module that sets a context value when executed.
-
-    The context need to pre exist in the context provider.
-    #TODO Is there a way to create the context if it doesn't exist?
-    """
-
-    def __init__(self, context: str, key: str, callback: Callable[[Any, Any], Any] | None = None) -> None:
-        super().__init__()
-        self.context = context
-        self.key = key
-        self.callback = callback
-
-    def __call__(self, x: Tensor) -> Tensor:
-        if context := self.use_context(self.context):
-            if not self.callback:
-                context.update({self.key: x})
-            else:
-                self.callback(context[self.key], x)
-
-        return x
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}(context={repr(self.context)}, key={repr(self.key)})"
-
-
-class ReturnException(Exception):
-    """Exception raised when a Return module is encountered."""
-
-    def __init__(self, value: Tensor):
-        self.value = value
-
-
-class Return(Module):
-    """A Module that stops the execution of a Chain when encountered."""
-
-    def forward(self, x: Tensor):
-        raise ReturnException(x)
-
-
 def structural_copy(m: T) -> T:
+    """Helper function to copy a Module's tree, only if it is a ContextModule instance."""
     return m.structural_copy() if isinstance(m, ContextModule) else m
 
@@ -122,6 +51,29 @@ class ChainError(RuntimeError):
 
 class Chain(ContextModule):
+    """Chain layer.
+
+    This layer is the main building block of Fluxion.
+    It is used to compose other layers in a sequential manner.
+    Similarly to [`torch.nn.Sequential`][torch.nn.Sequential],
+    it calls each of its sub-layers in order, chaining their outputs as inputs to the next sub-layer.
+    However, it also provides additional methods to manipulate its sub-layers and their context.
+
+    Example:
+        ```py
+        chain = fl.Chain(
+            fl.Linear(32, 64),
+            fl.ReLU(),
+            fl.Linear(64, 128),
+        )
+
+        tensor = torch.randn(2, 32)
+        output = chain(tensor)
+
+        assert output.shape == (2, 128)
+        ```
+    """
+
     _modules: dict[str, Module]
     _provider: ContextProvider
     _tag = "CHAIN"
@@ -165,12 +117,23 @@ class Chain(ContextModule):
 
     @property
     def provider(self) -> ContextProvider:
+        """The [`ContextProvider`][refiners.fluxion.context.ContextProvider] of the Chain."""
         return self._provider
 
     def init_context(self) -> Contexts:
+        """Initialize the context provider with some default values.
+
+        This method is called when the Chain is created, and when it is reset.
+        This method may be overridden by subclasses to provide default values for the context provider.
+        """
         return {}
 
-    def _register_provider(self, context: Contexts | None = None) -> None:
+    def _register_provider(self, context: Contexts | None = None) -> None:  # TODO: rename me?
+        """Recursively propagate the context provider to all sub-modules.
+
+        Args:
+            context: The context to use to update the provider.
+        """
         if context:
             self._provider.update_contexts(context)
 
@@ -179,9 +142,16 @@ class Chain(ContextModule):
             module._register_provider(context=self._provider.contexts)
 
     def _reset_context(self) -> None:
+        """Reset the context provider to its initial state."""
         self._register_provider(self.init_context())
 
     def set_context(self, context: str, value: Any) -> None:
+        """Set a value in the context provider.
+
+        Args:
+            context: The context to update.
+            value: The value to set.
+        """
         self._provider.set_context(context, value)
         self._register_provider()
 
@@ -315,18 +285,27 @@ class Chain(ContextModule):
 
     @property
     def device(self) -> Device | None:
+        """The PyTorch device of the Chain's parameters."""
         wm = self.find(WeightedModule)
         return None if wm is None else wm.device
 
     @property
     def dtype(self) -> DType | None:
+        """The PyTorch dtype of the Chain's parameters."""
         wm = self.find(WeightedModule)
         return None if wm is None else wm.dtype
 
     def _walk(
-        self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
+        self,
+        predicate: Callable[[Module, "Chain"], bool] | None = None,
+        recurse: bool = False,
     ) -> Iterator[tuple[Module, "Chain"]]:
+        """Walk the Chain's sub-module tree and yield each module that matches the predicate.
+
+        The predicate is a (Module, Chain) -> bool function.
+        """
         if predicate is None:
+            # if no predicate is given, yield all modules
            predicate = lambda _m, _p: True
         for module in self:
             try:
@@ -342,35 +321,100 @@ class Chain(ContextModule):
 
     @overload
     def walk(
-        self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
+        self,
+        predicate: Callable[[Module, "Chain"], bool] | None = None,
+        recurse: bool = False,
     ) -> Iterator[tuple[Module, "Chain"]]:
         ...
 
     @overload
-    def walk(self, predicate: type[T], recurse: bool = False) -> Iterator[tuple[T, "Chain"]]:
+    def walk(
+        self,
+        predicate: type[T],
+        recurse: bool = False,
+    ) -> Iterator[tuple[T, "Chain"]]:
         ...
 
     def walk(
-        self, predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
+        self,
+        predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None,
+        recurse: bool = False,
     ) -> Iterator[tuple[T, "Chain"]] | Iterator[tuple[Module, "Chain"]]:
-        if isinstance(predicate, type):
-            return self._walk(lambda m, _: isinstance(m, predicate), recurse)
-        else:
-            return self._walk(predicate, recurse)
+        """Walk the Chain's sub-module tree and yield each module that matches the predicate.
 
-    def layers(self, layer_type: type[T], recurse: bool = False) -> Iterator[T]:
+        Args:
+            predicate: The predicate to match.
+            recurse: Whether to recurse into sub-Chains.
+
+        Yields:
+            Each module that matches the predicate.
+        """
+        if isinstance(predicate, type):
+            # if the predicate is a Module type,
+            # build a predicate function that matches the type
+            return self._walk(
+                predicate=lambda m, _: isinstance(m, predicate),
+                recurse=recurse,
+            )
+        else:
+            return self._walk(
+                predicate=predicate,
+                recurse=recurse,
+            )
+
+    def layers(
+        self,
+        layer_type: type[T],
+        recurse: bool = False,
+    ) -> Iterator[T]:
+        """Walk the Chain's sub-module tree and yield each layer of the given type.
+
+        Args:
+            layer_type: The type of layer to yield.
+            recurse: Whether to recurse into sub-Chains.
+
+        Yields:
+            Each module of the given layer_type.
+        """
         for module, _ in self.walk(layer_type, recurse):
             yield module
 
     def find(self, layer_type: type[T]) -> T | None:
+        """Walk the Chain's sub-module tree and return the first layer of the given type.
+
+        Args:
+            layer_type: The type of layer to find.
+
+        Returns:
+            The first module of the given layer_type, or None if it doesn't exist.
+        """
         return next(self.layers(layer_type=layer_type), None)
 
     def ensure_find(self, layer_type: type[T]) -> T:
+        """Walk the Chain's sub-module tree and return the first layer of the given type.
+
+        Args:
+            layer_type: The type of layer to find.
+
+        Returns:
+            The first module of the given layer_type.
+
+        Raises:
+            AssertionError: If the module doesn't exist.
+        """
         r = self.find(layer_type)
         assert r is not None, f"could not find {layer_type} in {self}"
         return r
 
     def find_parent(self, module: Module) -> "Chain | None":
+        """Walk the Chain's sub-module tree and return the parent of the given module.
+
+        Args:
+            module: The module whose parent to find.
+
+        Returns:
+            The parent of the given module, or None if it doesn't exist.
+        """
         if module in self:  # avoid DFS-crawling the whole tree
             return self
         for _, parent in self.walk(lambda m, _: m == module):
@@ -378,11 +422,31 @@ class Chain(ContextModule):
         return None
 
     def ensure_find_parent(self, module: Module) -> "Chain":
+        """Walk the Chain's sub-module tree and return the parent of the given module.
+
+        Args:
+            module: The module whose parent to find.
+
+        Returns:
+            The parent of the given module.
+
+        Raises:
+            AssertionError: If the module doesn't exist.
+        """
         r = self.find_parent(module)
         assert r is not None, f"could not find {module} in {self}"
         return r
 
     def insert(self, index: int, module: Module) -> None:
+        """Insert a new module in the chain.
+
+        Args:
+            index: The index at which to insert the module.
+            module: The module to insert.
+
+        Raises:
+            IndexError: If the index is out of range.
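+
+        Example:
+            ```py
+            # a small, illustrative sketch (it reuses the layers from the
+            # Chain example above): insert an activation between two Linears
+            chain = fl.Chain(
+                fl.Linear(32, 64),
+                fl.Linear(64, 128),
+            )
+            chain.insert(1, fl.ReLU())
+
+            assert isinstance(chain[1], fl.ReLU)
+            ```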
+ """ if index < 0: index = max(0, len(self._modules) + index + 1) modules = list(self) @@ -393,6 +457,15 @@ class Chain(ContextModule): self._register_provider() def insert_before_type(self, module_type: type[Module], new_module: Module) -> None: + """Insert a new module in the chain, right before the first module of the given type. + + Args: + module_type: The type of module to insert before. + new_module: The module to insert. + + Raises: + ValueError: If no module of the given type exists in the chain. + """ for i, module in enumerate(self): if isinstance(module, module_type): self.insert(i, new_module) @@ -400,6 +473,15 @@ class Chain(ContextModule): raise ValueError(f"No module of type {module_type.__name__} found in the chain.") def insert_after_type(self, module_type: type[Module], new_module: Module) -> None: + """Insert a new module in the chain, right after the first module of the given type. + + Args: + module_type: The type of module to insert after. + new_module: The module to insert. + + Raises: + ValueError: If no module of the given type exists in the chain. + """ for i, module in enumerate(self): if isinstance(module, module_type): self.insert(i + 1, new_module) @@ -407,9 +489,25 @@ class Chain(ContextModule): raise ValueError(f"No module of type {module_type.__name__} found in the chain.") def append(self, module: Module) -> None: + """Append a new module to the chain. + + Args: + module: The module to append. + """ self.insert(-1, module) def pop(self, index: int = -1) -> Module | tuple[Module]: + """Pop a module from the chain at the given index. + + Args: + index: The index of the module to pop. + + Returns: + The popped module. + + Raises: + IndexError: If the index is out of range. + """ modules = list(self) if index < 0: index = len(modules) + index @@ -422,7 +520,14 @@ class Chain(ContextModule): return removed_module def remove(self, module: Module) -> None: - """Remove a module from the chain.""" + """Remove a module from the chain. + + Args: + module: The module to remove. + + Raises: + ValueError: If the module is not in the chain. + """ modules = list(self) try: modules.remove(module) @@ -438,7 +543,17 @@ class Chain(ContextModule): new_module: Module, old_module_parent: "Chain | None" = None, ) -> None: - """Replace a module in the chain with a new module.""" + """Replace a module in the chain with a new module. + + Args: + old_module: The module to replace. + new_module: The module to replace with. + old_module_parent: The parent of the old module. + If None, the old module is orphanized. + + Raises: + ValueError: If the module is not in the chain. + """ modules = list(self) try: modules[modules.index(old_module)] = new_module @@ -479,29 +594,221 @@ class Chain(ContextModule): return self.__class__ == Chain +class UseContext(ContextModule): + """UseContext layer. + + This layer reads from the [`ContextProvider`][refiners.fluxion.context.ContextProvider] + of its parent [`Chain`][refiners.fluxion.layers.chain.Chain]. 
+
+    Note: When called, it will
+        - Retrieve a value from the context using the given key
+        - Transform the value with the given function (optional)
+        - Return the value
+    """
+
+    def __init__(self, context: str, key: str) -> None:
+        super().__init__()
+        self.context = context
+        self.key = key
+        self.func: Callable[[Any], Any] = lambda x: x
+
+    def __call__(self, *args: Any) -> Any:
+        context = self.use_context(self.context)
+        assert context, f"context {self.context} is unset"
+        value = context.get(self.key)
+        assert value is not None, f"context entry {self.context}.{self.key} is unset"
+        return self.func(value)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(context={repr(self.context)}, key={repr(self.key)})"
+
+    def compose(self, func: Callable[[Any], Any]) -> "UseContext":
+        self.func = func
+        return self
+
+
+class SetContext(ContextModule):
+    """SetContext layer.
+
+    This layer writes to the [`ContextProvider`][refiners.fluxion.context.ContextProvider]
+    of its parent [`Chain`][refiners.fluxion.layers.chain.Chain].
+
+    Note: When called (without a callback), it will
+        - Update the context with the given key and the input value
+        - Return the input value
+
+    Note: When called (with a callback), it will
+        - Call the callback with the current value and the input value
+            (the callback may update the context with a new value, or not)
+        - Return the input value
+
+    Warning:
+        The context needs to already exist in the [`ContextProvider`][refiners.fluxion.context.ContextProvider].
+    """
+
+    # TODO: Create the context if it doesn't exist
+
+    def __init__(
+        self,
+        context: str,
+        key: str,
+        callback: Callable[[Any, Any], Any] | None = None,
+    ) -> None:
+        super().__init__()
+        self.context = context
+        self.key = key
+        self.callback = callback
+
+    def __call__(self, x: Tensor) -> Tensor:
+        if context := self.use_context(self.context):
+            if not self.callback:
+                context.update({self.key: x})
+            else:
+                self.callback(context[self.key], x)
+
+        return x
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(context={repr(self.context)}, key={repr(self.key)})"
+
+
+class Lambda(Module):
+    """Lambda layer.
+
+    This layer wraps a [`Callable`][typing.Callable].
+
+    Note: When called, it will
+        - Execute the [`Callable`][typing.Callable] with the given arguments
+        - Return the output of the [`Callable`][typing.Callable]
+
+    Example:
+        ```py
+        lambda_layer = fl.Lambda(lambda x: x + 1)
+
+        tensor = torch.tensor([1, 2, 3])
+        output = lambda_layer(tensor)
+
+        expected_output = torch.tensor([2, 3, 4])
+        assert torch.allclose(output, expected_output)
+        ```
+    """
+
+    def __init__(self, func: Callable[..., Any]) -> None:
+        super().__init__()
+        self.func = func
+
+    def forward(self, *args: Any) -> Any:
+        return self.func(*args)
+
+    def __str__(self) -> str:
+        func_name = getattr(self.func, "__name__", "partial_function")
+        return f"Lambda({func_name}{str(inspect.signature(self.func))})"
+
+
 class Parallel(Chain):
+    """Parallel layer.
+
+    This layer calls its sub-modules in parallel with the same inputs, and returns a tuple of their outputs.
+
+    Example:
+        ```py
+        parallel = fl.Parallel(
+            fl.Linear(32, 64),
+            fl.Identity(),
+            fl.Linear(32, 128),
+        )
+
+        tensor = torch.randn(2, 32)
+        outputs = parallel(tensor)
+
+        assert len(outputs) == 3
+        assert outputs[0].shape == (2, 64)
+        assert torch.allclose(outputs[1], tensor)
+        assert outputs[2].shape == (2, 128)
+        ```
+    """
+
     _tag = "PAR"
 
     def forward(self, *args: Any) -> tuple[Tensor, ...]:
-        return tuple([self._call_layer(module, name, *args) for name, module in self._modules.items()])
+        return tuple(
+            [
+                self._call_layer(
+                    module,
+                    name,
+                    *args,  # same input for all sub-modules
+                )
+                for name, module in self._modules.items()
+            ],
+        )
 
     def _show_only_tag(self) -> bool:
         return self.__class__ == Parallel
 
 
 class Distribute(Chain):
+    """Distribute layer.
+
+    This layer calls its sub-modules in parallel with their respective input, and returns a tuple of their outputs.
+
+    Example:
+        ```py
+        distribute = fl.Distribute(
+            fl.Linear(32, 128),
+            fl.Linear(64, 256),
+        )
+
+        tensor1 = torch.randn(2, 32)
+        tensor2 = torch.randn(4, 64)
+        outputs = distribute(tensor1, tensor2)
+
+        assert len(outputs) == 2
+        assert outputs[0].shape == (2, 128)
+        assert outputs[1].shape == (4, 256)
+        ```
+    """
+
     _tag = "DISTR"
 
     def forward(self, *args: Any) -> tuple[Tensor, ...]:
         n, m = len(args), len(self._modules)
         assert n == m, f"Number of positional arguments ({n}) must match number of sub-modules ({m})."
-        return tuple([self._call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())])
+        return tuple(
+            [
+                self._call_layer(
+                    module,
+                    name,
+                    arg,  # each sub-module has its own input
+                )
+                for arg, (name, module) in zip(args, self._modules.items())
+            ]
+        )
 
     def _show_only_tag(self) -> bool:
         return self.__class__ == Distribute
 
 
 class Passthrough(Chain):
+    """Passthrough layer.
+
+    This layer calls its sub-modules sequentially, and returns its original inputs,
+    like an [`Identity`][refiners.fluxion.layers.Identity] layer.
+
+    Example:
+        ```py
+        passthrough = fl.Passthrough(
+            fl.Linear(32, 128),
+            fl.ReLU(),
+            fl.Linear(128, 128),
+        )
+
+        tensor = torch.randn(2, 32)
+        output = passthrough(tensor)
+
+        assert torch.allclose(output, tensor)
+        ```
+    """
+
     _tag = "PASS"
 
     def forward(self, *inputs: Any) -> Any:
@@ -513,6 +820,24 @@ class Passthrough(Chain):
 
 class Sum(Chain):
+    """Summation layer.
+
+    This layer calls its sub-modules in parallel with the same inputs, and returns the sum of their outputs.
+
+    Example:
+        ```py
+        summation = fl.Sum(
+            fl.Multiply(scale=2, bias=1),
+            fl.Multiply(scale=3, bias=0),
+        )
+
+        tensor = torch.ones(1)
+        output = summation(tensor)
+
+        assert torch.allclose(output, torch.tensor([6.0]))
+        ```
+    """
+
     _tag = "SUM"
 
     def forward(self, *inputs: Any) -> Any:
@@ -529,6 +854,24 @@ class Sum(Chain):
 
 class Residual(Chain):
+    """Residual layer.
+
+    This layer calls its sub-modules sequentially, and adds the original input to the output.
+
+    Example:
+        ```py
+        residual = fl.Residual(
+            fl.Multiply(scale=10),
+        )
+
+        tensor = torch.ones(2, 32)
+        output = residual(tensor)
+
+        assert output.shape == (2, 32)
+        assert torch.allclose(output, 10 * tensor + tensor)
+        ```
+    """
+
     _tag = "RES"
 
     def forward(self, *inputs: Any) -> Any:
@@ -536,7 +879,101 @@ class Residual(Chain):
         return super().forward(*inputs) + inputs[0]
 
+
+class Concatenate(Chain):
+    """Concatenation layer.
+
+    This layer calls its sub-modules in parallel with the same inputs, and returns the concatenation of their outputs.
+
+    Example:
+        ```py
+        concatenate = fl.Concatenate(
+            fl.Linear(32, 128),
+            fl.Linear(32, 128),
+            dim=1,
+        )
+
+        tensor = torch.randn(2, 32)
+        output = concatenate(tensor)
+
+        assert output.shape == (2, 256)
+        ```
+    """
+
+    _tag = "CAT"
+
+    def __init__(self, *modules: Module, dim: int = 0) -> None:
+        super().__init__(*modules)
+        self.dim = dim
+
+    def forward(self, *args: Any) -> Tensor:
+        outputs = [module(*args) for module in self]
+        return cat(
+            [output for output in outputs if output is not None],
+            dim=self.dim,
+        )
+
+    def _show_only_tag(self) -> bool:
+        return self.__class__ == Concatenate
+
+
+class Matmul(Chain):
+    """Matrix multiplication layer.
+
+    This layer returns the matrix multiplication of the outputs of its two sub-modules.
+
+    Example:
+        ```py
+        matmul = fl.Matmul(
+            fl.Identity(),
+            fl.Multiply(scale=2),
+        )
+
+        tensor = torch.randn(10, 10)
+        output = matmul(tensor)
+
+        expected_output = tensor @ (2 * tensor)
+        assert torch.allclose(output, expected_output)
+        ```
+    """
+
+    _tag = "MATMUL"
+
+    def __init__(self, input: Module, other: Module) -> None:
+        super().__init__(
+            input,
+            other,
+        )
+
+    def forward(self, *args: Tensor) -> Tensor:
+        return torch.matmul(
+            input=self[0](*args),
+            other=self[1](*args),
+        )
+
+
+class ReturnException(Exception):
+    """Exception raised when a Return module is encountered."""
+
+    def __init__(self, value: Tensor):
+        self.value = value
+
+
+class Return(Module):
+    """Return layer.
+
+    This layer stops the execution of a Chain when encountered.
+    """
+
+    def forward(self, x: Tensor):
+        raise ReturnException(x)
+
+
 class Breakpoint(ContextModule):
+    """Breakpoint layer.
+
+    This layer pauses the execution when encountered, and opens a debugger.
+    """
+
     def __init__(self, vscode: bool = True):
         super().__init__()
         self.vscode = vscode
@@ -549,31 +986,3 @@
         else:
             breakpoint()
         return args[0] if len(args) == 1 else args
-
-
-class Concatenate(Chain):
-    _tag = "CAT"
-
-    def __init__(self, *modules: Module, dim: int = 0) -> None:
-        super().__init__(*modules)
-        self.dim = dim
-
-    def forward(self, *args: Any) -> Tensor:
-        outputs = [module(*args) for module in self]
-        return cat([output for output in outputs if output is not None], dim=self.dim)
-
-    def _show_only_tag(self) -> bool:
-        return self.__class__ == Concatenate
-
-
-class Matmul(Chain):
-    _tag = "MATMUL"
-
-    def __init__(self, input: Module, other: Module) -> None:
-        super().__init__(
-            input,
-            other,
-        )
-
-    def forward(self, *args: Tensor) -> Tensor:
-        return torch.matmul(input=self[0](*args), other=self[1](*args))