From cf43cb191fb8217fa833b18f6019cd979c6b1d24 Mon Sep 17 00:00:00 2001 From: Benjamin Trom Date: Thu, 7 Sep 2023 16:15:02 +0200 Subject: [PATCH] Add better tree representation for fluxion Module --- src/refiners/fluxion/layers/basics.py | 40 +------- src/refiners/fluxion/layers/chain.py | 52 ++++++----- src/refiners/fluxion/layers/conv.py | 7 +- src/refiners/fluxion/layers/module.py | 127 +++++++++++++++++++++++++- 4 files changed, 160 insertions(+), 66 deletions(-) diff --git a/src/refiners/fluxion/layers/basics.py b/src/refiners/fluxion/layers/basics.py index ada1979..03b48b8 100644 --- a/src/refiners/fluxion/layers/basics.py +++ b/src/refiners/fluxion/layers/basics.py @@ -19,10 +19,6 @@ class View(Module): def forward(self, x: Tensor) -> Tensor: return x.view(*self.shape) - def __repr__(self): - shape_repr = ", ".join([repr(s) for s in self.shape]) - return f"{self.__class__.__name__}({shape_repr})" - class Flatten(Module): def __init__(self, start_dim: int = 0, end_dim: int = -1) -> None: @@ -33,9 +29,6 @@ class Flatten(Module): def forward(self, x: Tensor) -> Tensor: return x.flatten(self.start_dim, self.end_dim) - def __repr__(self): - return f"{self.__class__.__name__}(start_dim={repr(self.start_dim)}, end_dim={repr(self.end_dim)})" - class Unflatten(Module): def __init__(self, dim: int) -> None: @@ -45,9 +38,6 @@ class Unflatten(Module): def forward(self, x: Tensor, sizes: Size) -> Tensor: return x.unflatten(self.dim, sizes) # type: ignore - def __repr__(self): - return f"{self.__class__.__name__}(dim={repr(self.dim)})" - class Reshape(Module): """ @@ -62,10 +52,6 @@ class Reshape(Module): def forward(self, x: Tensor) -> Tensor: return x.reshape(x.shape[0], *self.shape) - def __repr__(self): - shape_repr = ", ".join([repr(s) for s in self.shape]) - return f"{self.__class__.__name__}({shape_repr})" - class Transpose(Module): def __init__(self, dim0: int, dim1: int) -> None: @@ -76,9 +62,6 @@ class Transpose(Module): def forward(self, x: Tensor) -> Tensor: return x.transpose(self.dim0, self.dim1) - def __repr__(self): - return f"{self.__class__.__name__}(dim0={repr(self.dim0)}, dim1={repr(self.dim1)})" - class Permute(Module): def __init__(self, *dims: int) -> None: @@ -88,10 +71,6 @@ class Permute(Module): def forward(self, x: Tensor) -> Tensor: return x.permute(*self.dims) - def __repr__(self): - dims_repr = ", ".join([repr(d) for d in self.dims]) - return f"{self.__class__.__name__}({dims_repr})" - class Slicing(Module): def __init__(self, dim: int, start: int, length: int) -> None: @@ -103,9 +82,6 @@ class Slicing(Module): def forward(self, x: Tensor) -> Tensor: return x.narrow(self.dim, self.start, self.length) - def __repr__(self): - return f"{self.__class__.__name__}(dim={repr(self.dim)}, start={repr(self.start)}, length={repr(self.length)})" - class Squeeze(Module): def __init__(self, dim: int) -> None: @@ -115,9 +91,6 @@ class Squeeze(Module): def forward(self, x: Tensor) -> Tensor: return x.squeeze(self.dim) - def __repr__(self): - return f"{self.__class__.__name__}(dim={repr(self.dim)})" - class Unsqueeze(Module): def __init__(self, dim: int) -> None: @@ -127,9 +100,6 @@ class Unsqueeze(Module): def forward(self, x: Tensor) -> Tensor: return x.unsqueeze(self.dim) - def __repr__(self): - return f"{self.__class__.__name__}(dim={repr(self.dim)})" - class Parameter(WeightedModule): """ @@ -138,6 +108,7 @@ class Parameter(WeightedModule): def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None: super().__init__() + self.dims = dims self.register_parameter("parameter", TorchParameter(randn(*dims, device=device, dtype=dtype))) @property @@ -151,10 +122,6 @@ class Parameter(WeightedModule): def forward(self, _: Tensor) -> Tensor: return self.parameter - def __repr__(self): - dims_repr = ", ".join([repr(d) for d in list(self.parameter.shape)]) - return f"{self.__class__.__name__}({dims_repr}, device={repr(self.device)})" - class Buffer(WeightedModule): """ @@ -165,6 +132,7 @@ class Buffer(WeightedModule): def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None: super().__init__() + self.dims = dims self.register_buffer("buffer", randn(*dims, device=device, dtype=dtype)) @property @@ -177,7 +145,3 @@ class Buffer(WeightedModule): def forward(self, _: Tensor) -> Tensor: return self.buffer - - def __repr__(self): - dims_repr = ", ".join([repr(d) for d in list(self.buffer.shape)]) - return f"{self.__class__.__name__}({dims_repr}, device={repr(self.device)})" diff --git a/src/refiners/fluxion/layers/chain.py b/src/refiners/fluxion/layers/chain.py index 6669cda..a6081f9 100644 --- a/src/refiners/fluxion/layers/chain.py +++ b/src/refiners/fluxion/layers/chain.py @@ -20,7 +20,7 @@ class Lambda(Module): def forward(self, *args: Any) -> Any: return self.func(*args) - def __repr__(self): + def __str__(self) -> str: func_name = getattr(self.func, "__name__", "partial_function") return f"Lambda({func_name}{str(inspect.signature(self.func))})" @@ -115,6 +115,7 @@ def structural_copy(m: T) -> T: class Chain(ContextModule): _modules: dict[str, Module] _provider: ContextProvider + _tag = "CHAIN" def __init__(self, *args: Module | Iterable[Module]) -> None: super().__init__() @@ -235,28 +236,6 @@ class Chain(ContextModule): def __iter__(self) -> Iterator[Module]: return iter(self._modules.values()) - def _pretty_print(self, num_tab: int = 0, layer_name: str | None = None) -> str: - layer_name = self.__class__.__name__ if layer_name is None else layer_name - pretty_print = f"{layer_name}:\n" - tab = " " * (num_tab + 4) - module_strings: list[str] = [] - for i, (name, module) in enumerate(self._modules.items()): - ident = ("└+" if isinstance(self, Sum) else "└─") if i == 0 else " " - module_str = ( - module - if not isinstance(module, Chain) - else (module._pretty_print(len(tab), name) if num_tab < 12 else f"{name}(...)") - ) - module_strings.append(f"{tab}{ident} {module_str}") - pretty_print += "\n".join(module_strings) - return pretty_print - - def __repr__(self) -> str: - return self._pretty_print() - - def __str__(self) -> str: - return f"<{self.__class__.__name__} at {hex(id(self))}>" - def __len__(self) -> int: return len(self._modules) @@ -418,25 +397,45 @@ class Chain(ContextModule): return clone + def _show_only_tag(self) -> bool: + return self.__class__ == Chain + class Parallel(Chain): + _tag = "PAR" + def forward(self, *args: Any) -> tuple[Tensor, ...]: return tuple([self.call_layer(module, name, *args) for name, module in self._modules.items()]) + def _show_only_tag(self) -> bool: + return self.__class__ == Parallel + class Distribute(Chain): + _tag = "DISTR" + def forward(self, *args: Any) -> tuple[Tensor, ...]: assert len(args) == len(self._modules), "Number of positional arguments must match number of sub-modules." return tuple([self.call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())]) + def _show_only_tag(self) -> bool: + return self.__class__ == Distribute + class Passthrough(Chain): + _tag = "PASS" + def forward(self, *inputs: Any) -> Any: super().forward(*inputs) return inputs + def _show_only_tag(self) -> bool: + return self.__class__ == Passthrough + class Sum(Chain): + _tag = "SUM" + def forward(self, *inputs: Any) -> Any: output = None for layer in self: @@ -446,6 +445,9 @@ class Sum(Chain): output = layer_output if output is None else output + layer_output return output + def _show_only_tag(self) -> bool: + return self.__class__ == Sum + class Residual(Sum): def __init__(self, *modules: Module) -> None: @@ -468,6 +470,7 @@ class Breakpoint(ContextModule): class Concatenate(Chain): + _tag = "CAT" structural_attrs = ["dim"] def __init__(self, *modules: Module, dim: int = 0) -> None: @@ -477,3 +480,6 @@ class Concatenate(Chain): def forward(self, *args: Any) -> Tensor: outputs = [module(*args) for module in self] return cat([output for output in outputs if output is not None], dim=self.dim) + + def _show_only_tag(self) -> bool: + return self.__class__ == Concatenate diff --git a/src/refiners/fluxion/layers/conv.py b/src/refiners/fluxion/layers/conv.py index fd24308..baab860 100644 --- a/src/refiners/fluxion/layers/conv.py +++ b/src/refiners/fluxion/layers/conv.py @@ -8,11 +8,11 @@ class Conv2d(nn.Conv2d, WeightedModule): in_channels: int, out_channels: int, kernel_size: int | tuple[int, int], - stride: int | tuple[int, int] = 1, - padding: int | tuple[int, int] | str = 0, + stride: int | tuple[int, int] = (1, 1), + padding: int | tuple[int, int] | str = (0, 0), groups: int = 1, use_bias: bool = True, - dilation: int | tuple[int, int] = 1, + dilation: int | tuple[int, int] = (1, 1), padding_mode: str = "zeros", device: Device | str | None = None, dtype: DType | None = None, @@ -30,6 +30,7 @@ class Conv2d(nn.Conv2d, WeightedModule): device, dtype, ) + self.use_bias = use_bias class Conv1d(nn.Conv1d, WeightedModule): diff --git a/src/refiners/fluxion/layers/module.py b/src/refiners/fluxion/layers/module.py index edf864b..5a4ea2d 100644 --- a/src/refiners/fluxion/layers/module.py +++ b/src/refiners/fluxion/layers/module.py @@ -1,5 +1,6 @@ +from inspect import signature, Parameter from pathlib import Path -from typing import Any, Generator, TypeVar +from typing import Any, Generator, TypeVar, TypedDict, cast from torch import device as Device, dtype as DType from torch.nn.modules.module import Module as TorchModule @@ -7,18 +8,20 @@ from torch.nn.modules.module import Module as TorchModule from refiners.fluxion.utils import load_from_safetensors from refiners.fluxion.context import Context, ContextProvider -from typing import Callable, TYPE_CHECKING +from typing import Callable, TYPE_CHECKING, Sequence if TYPE_CHECKING: from refiners.fluxion.layers.chain import Chain T = TypeVar("T", bound="Module") TContextModule = TypeVar("TContextModule", bound="ContextModule") +BasicType = str | float | int | bool class Module(TorchModule): _parameters: dict[str, Any] _buffers: dict[str, Any] + _tag: str = "" __getattr__: Callable[["Module", str], Any] # type: ignore __setattr__: Callable[["Module", str, Any], None] # type: ignore @@ -37,6 +40,56 @@ class Module(TorchModule): def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore return super().to(device=device, dtype=dtype) # type: ignore + def __str__(self) -> str: + basic_attributes_str = ", ".join( + f"{key}={value}" for key, value in self.basic_attributes(init_attrs_only=True).items() + ) + result = f"{self.__class__.__name__}({basic_attributes_str})" + return result + + def __repr__(self) -> str: + tree = ModuleTree(module=self) + return repr(tree) + + def pretty_print(self, depth: int = -1) -> None: + tree = ModuleTree(module=self) + print(tree.generate_tree_repr(tree.root, is_root=True, depth=depth)) + + def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType]: + """Return a dictionary of basic attributes of the module. + + Basic attributes are public attributes made of basic types (int, float, str, bool) or a sequence of basic types. + """ + sig = signature(obj=self.__init__) + init_params = set(sig.parameters.keys()) - {"self"} + default_values = {k: v.default for k, v in sig.parameters.items() if v.default is not Parameter.empty} + + def is_basic_attribute(key: str, value: Any) -> bool: + if key.startswith("_"): + return False + + if isinstance(value, BasicType): + return True + + if isinstance(value, Sequence) and all(isinstance(y, BasicType) for y in cast(Sequence[Any], value)): + return True + + return False + + return { + key: str(object=value) + for key, value in self.__dict__.items() + if is_basic_attribute(key=key, value=value) + and (not init_attrs_only or (key in init_params and value != default_values.get(key))) + } + + def _show_only_tag(self) -> bool: + """Whether to show only the tag when printing the module. + + This is useful to distinguish between Chain subclasses that override their forward from one another. + """ + return False + class ContextModule(Module): # we store parent into a one element list to avoid pytorch thinking it's a submodule @@ -100,3 +153,73 @@ class WeightedModule(Module): @property def dtype(self) -> DType: return self.weight.dtype + + +class TreeNode(TypedDict): + value: str + children: list["TreeNode"] + + +class ModuleTree: + def __init__(self, module: Module) -> None: + self.root: TreeNode = self._module_to_tree(module=module) + self._fold_successive_identical(node=self.root) + + def __str__(self) -> str: + return f"{self.__class__.__name__}(root={self.root['value']})" + + def __repr__(self) -> str: + return self.generate_tree_repr(node=self.root, is_root=True, depth=7) + + def generate_tree_repr( + self, node: TreeNode, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1 + ) -> str: + if depth == 0: + return "" + + if depth > 0: + depth -= 1 + + tree_icon: str = "" if is_root else ("└── " if is_last else "├── ") + lines = [f"{prefix}{tree_icon}{node['value']}"] + new_prefix: str = " " if is_last else "│ " + + for i, child in enumerate(iterable=node["children"]): + lines.append( + self.generate_tree_repr( + node=child, + prefix=prefix + new_prefix, + is_last=i == len(node["children"]) - 1, + is_root=False, + depth=depth, + ) + ) + + return "\n".join(filter(bool, lines)) + + def _module_to_tree(self, module: Module) -> TreeNode: + match (module._tag, module._show_only_tag()): # pyright: ignore[reportPrivateUsage] + case ("", False): + value = str(object=module) + case (_, True): + value = f"({module._tag})" # pyright: ignore[reportPrivateUsage] + case (_, False): + value = f"({module._tag}) {module}" # pyright: ignore[reportPrivateUsage] + + node: TreeNode = {"value": value, "children": []} + for child in module.children(): + node["children"].append(self._module_to_tree(module=child)) # type: ignore + return node + + def _fold_successive_identical(self, node: TreeNode) -> None: + i = 0 + while i < len(node["children"]): + j = i + while j < len(node["children"]) and node["children"][i] == node["children"][j]: + j += 1 + count = j - i + if count > 1: + node["children"][i]["value"] += f" (x{count})" + del node["children"][i + 1 : j] + self._fold_successive_identical(node=node["children"][i]) + i += 1