Add better tree representation for fluxion Module

Benjamin Trom 2023-09-07 16:15:02 +02:00
parent d9a461e9b5
commit cf43cb191f
4 changed files with 160 additions and 66 deletions

View file

@@ -19,10 +19,6 @@ class View(Module):
     def forward(self, x: Tensor) -> Tensor:
         return x.view(*self.shape)
-
-    def __repr__(self):
-        shape_repr = ", ".join([repr(s) for s in self.shape])
-        return f"{self.__class__.__name__}({shape_repr})"

 class Flatten(Module):
     def __init__(self, start_dim: int = 0, end_dim: int = -1) -> None:
@@ -33,9 +29,6 @@ class Flatten(Module):
     def forward(self, x: Tensor) -> Tensor:
         return x.flatten(self.start_dim, self.end_dim)
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}(start_dim={repr(self.start_dim)}, end_dim={repr(self.end_dim)})"

 class Unflatten(Module):
     def __init__(self, dim: int) -> None:
@@ -45,9 +38,6 @@ class Unflatten(Module):
     def forward(self, x: Tensor, sizes: Size) -> Tensor:
         return x.unflatten(self.dim, sizes)  # type: ignore
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}(dim={repr(self.dim)})"

 class Reshape(Module):
     """
@@ -62,10 +52,6 @@ class Reshape(Module):
     def forward(self, x: Tensor) -> Tensor:
         return x.reshape(x.shape[0], *self.shape)
-
-    def __repr__(self):
-        shape_repr = ", ".join([repr(s) for s in self.shape])
-        return f"{self.__class__.__name__}({shape_repr})"

 class Transpose(Module):
     def __init__(self, dim0: int, dim1: int) -> None:
@@ -76,9 +62,6 @@ class Transpose(Module):
     def forward(self, x: Tensor) -> Tensor:
         return x.transpose(self.dim0, self.dim1)
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}(dim0={repr(self.dim0)}, dim1={repr(self.dim1)})"

 class Permute(Module):
     def __init__(self, *dims: int) -> None:
@@ -88,10 +71,6 @@ class Permute(Module):
     def forward(self, x: Tensor) -> Tensor:
         return x.permute(*self.dims)
-
-    def __repr__(self):
-        dims_repr = ", ".join([repr(d) for d in self.dims])
-        return f"{self.__class__.__name__}({dims_repr})"

 class Slicing(Module):
     def __init__(self, dim: int, start: int, length: int) -> None:
@@ -103,9 +82,6 @@ class Slicing(Module):
     def forward(self, x: Tensor) -> Tensor:
         return x.narrow(self.dim, self.start, self.length)
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}(dim={repr(self.dim)}, start={repr(self.start)}, length={repr(self.length)})"

 class Squeeze(Module):
     def __init__(self, dim: int) -> None:
@@ -115,9 +91,6 @@ class Squeeze(Module):
     def forward(self, x: Tensor) -> Tensor:
         return x.squeeze(self.dim)
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}(dim={repr(self.dim)})"

 class Unsqueeze(Module):
     def __init__(self, dim: int) -> None:
@@ -127,9 +100,6 @@ class Unsqueeze(Module):
     def forward(self, x: Tensor) -> Tensor:
         return x.unsqueeze(self.dim)
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}(dim={repr(self.dim)})"

 class Parameter(WeightedModule):
     """
@@ -138,6 +108,7 @@ class Parameter(WeightedModule):
     def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
         super().__init__()
+        self.dims = dims
         self.register_parameter("parameter", TorchParameter(randn(*dims, device=device, dtype=dtype)))

     @property
@@ -151,10 +122,6 @@ class Parameter(WeightedModule):
     def forward(self, _: Tensor) -> Tensor:
         return self.parameter
-
-    def __repr__(self):
-        dims_repr = ", ".join([repr(d) for d in list(self.parameter.shape)])
-        return f"{self.__class__.__name__}({dims_repr}, device={repr(self.device)})"

 class Buffer(WeightedModule):
     """
@@ -165,6 +132,7 @@ class Buffer(WeightedModule):
     def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
         super().__init__()
+        self.dims = dims
         self.register_buffer("buffer", randn(*dims, device=device, dtype=dtype))

     @property
@@ -177,7 +145,3 @@ class Buffer(WeightedModule):
     def forward(self, _: Tensor) -> Tensor:
         return self.buffer
-
-    def __repr__(self):
-        dims_repr = ", ".join([repr(d) for d in list(self.buffer.shape)])
-        return f"{self.__class__.__name__}({dims_repr}, device={repr(self.device)})"

View file

@@ -20,7 +20,7 @@ class Lambda(Module):
     def forward(self, *args: Any) -> Any:
         return self.func(*args)

-    def __repr__(self):
+    def __str__(self) -> str:
         func_name = getattr(self.func, "__name__", "partial_function")
         return f"Lambda({func_name}{str(inspect.signature(self.func))})"
@@ -115,6 +115,7 @@ def structural_copy(m: T) -> T:
 class Chain(ContextModule):
     _modules: dict[str, Module]
     _provider: ContextProvider
+    _tag = "CHAIN"

     def __init__(self, *args: Module | Iterable[Module]) -> None:
         super().__init__()
@@ -235,28 +236,6 @@ class Chain(ContextModule):
     def __iter__(self) -> Iterator[Module]:
         return iter(self._modules.values())
-
-    def _pretty_print(self, num_tab: int = 0, layer_name: str | None = None) -> str:
-        layer_name = self.__class__.__name__ if layer_name is None else layer_name
-        pretty_print = f"{layer_name}:\n"
-        tab = " " * (num_tab + 4)
-        module_strings: list[str] = []
-        for i, (name, module) in enumerate(self._modules.items()):
-            ident = ("└+" if isinstance(self, Sum) else "└─") if i == 0 else " "
-            module_str = (
-                module
-                if not isinstance(module, Chain)
-                else (module._pretty_print(len(tab), name) if num_tab < 12 else f"{name}(...)")
-            )
-            module_strings.append(f"{tab}{ident} {module_str}")
-        pretty_print += "\n".join(module_strings)
-        return pretty_print
-
-    def __repr__(self) -> str:
-        return self._pretty_print()
-
-    def __str__(self) -> str:
-        return f"<{self.__class__.__name__} at {hex(id(self))}>"

     def __len__(self) -> int:
         return len(self._modules)
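
Note: the removed _pretty_print / __repr__ / __str__ trio on Chain (previously repr() returned the hand-rolled tree and str() returned "<Chain at 0x...>") is superseded by the tree machinery added to the Module base class later in this commit. A rough sketch of the new entry points (not part of the diff):

    import refiners.fluxion.layers as fl

    chain = fl.Chain(fl.Chain(), fl.Chain())  # trivial chain, just to have something to print
    print(str(chain))      # compact one-liner built from basic_attributes, e.g. "Chain()"
    print(repr(chain))     # full tree rendered by ModuleTree, depth capped at 7
    chain.pretty_print()   # same tree printed directly; pass depth=N to truncate
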
@@ -418,25 +397,45 @@ class Chain(ContextModule):
         return clone

+    def _show_only_tag(self) -> bool:
+        return self.__class__ == Chain
+

 class Parallel(Chain):
+    _tag = "PAR"
+
     def forward(self, *args: Any) -> tuple[Tensor, ...]:
         return tuple([self.call_layer(module, name, *args) for name, module in self._modules.items()])

+    def _show_only_tag(self) -> bool:
+        return self.__class__ == Parallel
+

 class Distribute(Chain):
+    _tag = "DISTR"
+
     def forward(self, *args: Any) -> tuple[Tensor, ...]:
         assert len(args) == len(self._modules), "Number of positional arguments must match number of sub-modules."
         return tuple([self.call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())])

+    def _show_only_tag(self) -> bool:
+        return self.__class__ == Distribute
+

 class Passthrough(Chain):
+    _tag = "PASS"
+
     def forward(self, *inputs: Any) -> Any:
         super().forward(*inputs)
         return inputs

+    def _show_only_tag(self) -> bool:
+        return self.__class__ == Passthrough
+

 class Sum(Chain):
+    _tag = "SUM"
+
     def forward(self, *inputs: Any) -> Any:
         output = None
         for layer in self:
@@ -446,6 +445,9 @@ class Sum(Chain):
             output = layer_output if output is None else output + layer_output
         return output

+    def _show_only_tag(self) -> bool:
+        return self.__class__ == Sum
+

 class Residual(Sum):
     def __init__(self, *modules: Module) -> None:
@@ -468,6 +470,7 @@ class Breakpoint(ContextModule):

 class Concatenate(Chain):
+    _tag = "CAT"
     structural_attrs = ["dim"]

     def __init__(self, *modules: Module, dim: int = 0) -> None:
@@ -477,3 +480,6 @@ class Concatenate(Chain):
     def forward(self, *args: Any) -> Tensor:
         outputs = [module(*args) for module in self]
         return cat([output for output in outputs if output is not None], dim=self.dim)
+
+    def _show_only_tag(self) -> bool:
+        return self.__class__ == Concatenate
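
Note: every Chain subclass now carries a short tag, and _show_only_tag returns True only for the exact class. A plain Sum therefore collapses to just "(SUM)" in the tree, while a subclass such as Residual keeps its own name next to the inherited tag. A sketch of the three display cases handled by ModuleTree._module_to_tree (the Identity layer and the exact output are assumptions, not part of this diff):

    import refiners.fluxion.layers as fl

    # _tag == "" and _show_only_tag() is False  ->  plain str(module), e.g. "Flatten(start_dim=1)"
    # _show_only_tag() is True (exact class)    ->  tag only, e.g. "(SUM)"
    # tag inherited by a subclass               ->  tag plus str(module), e.g. "(SUM) Residual()"
    print(repr(fl.Sum(fl.Identity(), fl.Identity())))
    # expected output (assumed):
    # (SUM)
    #     └── Identity() (x2)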

View file

@@ -8,11 +8,11 @@ class Conv2d(nn.Conv2d, WeightedModule):
         in_channels: int,
         out_channels: int,
         kernel_size: int | tuple[int, int],
-        stride: int | tuple[int, int] = 1,
-        padding: int | tuple[int, int] | str = 0,
+        stride: int | tuple[int, int] = (1, 1),
+        padding: int | tuple[int, int] | str = (0, 0),
         groups: int = 1,
         use_bias: bool = True,
-        dilation: int | tuple[int, int] = 1,
+        dilation: int | tuple[int, int] = (1, 1),
         padding_mode: str = "zeros",
         device: Device | str | None = None,
         dtype: DType | None = None,
@@ -30,6 +30,7 @@ class Conv2d(nn.Conv2d, WeightedModule):
             device,
             dtype,
         )
+        self.use_bias = use_bias

 class Conv1d(nn.Conv1d, WeightedModule):
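
Note: the scalar defaults become tuples presumably because nn.Conv2d normalizes stride, padding and dilation to pairs internally, and the new basic_attributes(init_attrs_only=True) filter on Module only hides values that equal the declared default; with tuple defaults, unchanged settings stay out of the compact string. Storing use_bias makes the bias choice visible as a basic attribute, since the bias itself is a parameter rather than a basic type. A hedged example (the exact output is an assumption):

    import refiners.fluxion.layers as fl

    conv = fl.Conv2d(in_channels=4, out_channels=8, kernel_size=3)
    print(conv)
    # expected: Conv2d(in_channels=4, out_channels=8, kernel_size=(3, 3))
    # stride, padding and dilation stay hidden because their stored (normalized) values
    # now equal the tuple defaults; use_bias would only show up when set to False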

View file

@@ -1,5 +1,6 @@
+from inspect import signature, Parameter
 from pathlib import Path
-from typing import Any, Generator, TypeVar
+from typing import Any, Generator, TypeVar, TypedDict, cast

 from torch import device as Device, dtype as DType
 from torch.nn.modules.module import Module as TorchModule
@@ -7,18 +8,20 @@ from torch.nn.modules.module import Module as TorchModule
 from refiners.fluxion.utils import load_from_safetensors
 from refiners.fluxion.context import Context, ContextProvider

-from typing import Callable, TYPE_CHECKING
+from typing import Callable, TYPE_CHECKING, Sequence

 if TYPE_CHECKING:
     from refiners.fluxion.layers.chain import Chain

 T = TypeVar("T", bound="Module")
 TContextModule = TypeVar("TContextModule", bound="ContextModule")
+BasicType = str | float | int | bool

 class Module(TorchModule):
     _parameters: dict[str, Any]
     _buffers: dict[str, Any]
+    _tag: str = ""

     __getattr__: Callable[["Module", str], Any]  # type: ignore
     __setattr__: Callable[["Module", str, Any], None]  # type: ignore
@@ -37,6 +40,56 @@ class Module(TorchModule):
     def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T:  # type: ignore
         return super().to(device=device, dtype=dtype)  # type: ignore

+    def __str__(self) -> str:
+        basic_attributes_str = ", ".join(
+            f"{key}={value}" for key, value in self.basic_attributes(init_attrs_only=True).items()
+        )
+        result = f"{self.__class__.__name__}({basic_attributes_str})"
+        return result
+
+    def __repr__(self) -> str:
+        tree = ModuleTree(module=self)
+        return repr(tree)
+
+    def pretty_print(self, depth: int = -1) -> None:
+        tree = ModuleTree(module=self)
+        print(tree.generate_tree_repr(tree.root, is_root=True, depth=depth))
+
+    def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType]:
+        """Return a dictionary of basic attributes of the module.
+
+        Basic attributes are public attributes made of basic types (int, float, str, bool) or a sequence of basic types.
+        """
+        sig = signature(obj=self.__init__)
+        init_params = set(sig.parameters.keys()) - {"self"}
+        default_values = {k: v.default for k, v in sig.parameters.items() if v.default is not Parameter.empty}
+
+        def is_basic_attribute(key: str, value: Any) -> bool:
+            if key.startswith("_"):
+                return False
+            if isinstance(value, BasicType):
+                return True
+            if isinstance(value, Sequence) and all(isinstance(y, BasicType) for y in cast(Sequence[Any], value)):
+                return True
+            return False
+
+        return {
+            key: str(object=value)
+            for key, value in self.__dict__.items()
+            if is_basic_attribute(key=key, value=value)
+            and (not init_attrs_only or (key in init_params and value != default_values.get(key)))
+        }
+
+    def _show_only_tag(self) -> bool:
+        """Whether to show only the tag when printing the module.
+
+        This is useful to distinguish between Chain subclasses that override their forward from one another.
+        """
+        return False

 class ContextModule(Module):
     # we store parent into a one element list to avoid pytorch thinking it's a submodule
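
Note: basic_attributes inspects the constructor signature and keeps only public __dict__ entries whose values are basic types (or sequences of basic types); with init_attrs_only=True it further restricts the result to constructor parameters whose current value differs from the declared default, which is what __str__ uses. A small illustration with a hypothetical module (Window is invented for this example, and the import path for Module is assumed):

    from refiners.fluxion.layers import Module

    class Window(Module):  # hypothetical layer, only here to illustrate the filtering
        def __init__(self, size: int = 7, shift: int = 0, name: str = "w") -> None:
            super().__init__()
            self.size = size
            self.shift = shift
            self.name = name
            self._cache: dict[str, int] = {}  # private attribute, never reported

    w = Window(size=14)
    print(w.basic_attributes(init_attrs_only=True))  # {'size': '14'}, shift and name equal their defaults
    print(w)                                         # Window(size=14)
    # Without init_attrs_only, public basic attributes set by torch itself
    # (such as the boolean training flag) are reported as well, with values str()-ed.
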
@@ -100,3 +153,73 @@ class WeightedModule(Module):
     @property
     def dtype(self) -> DType:
         return self.weight.dtype
+
+
+class TreeNode(TypedDict):
+    value: str
+    children: list["TreeNode"]
+
+
+class ModuleTree:
+    def __init__(self, module: Module) -> None:
+        self.root: TreeNode = self._module_to_tree(module=module)
+        self._fold_successive_identical(node=self.root)
+
+    def __str__(self) -> str:
+        return f"{self.__class__.__name__}(root={self.root['value']})"
+
+    def __repr__(self) -> str:
+        return self.generate_tree_repr(node=self.root, is_root=True, depth=7)
+
+    def generate_tree_repr(
+        self, node: TreeNode, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1
+    ) -> str:
+        if depth == 0:
+            return ""
+        if depth > 0:
+            depth -= 1
+
+        tree_icon: str = "" if is_root else ("└── " if is_last else "├── ")
+        lines = [f"{prefix}{tree_icon}{node['value']}"]
+        new_prefix: str = "    " if is_last else "│   "
+        for i, child in enumerate(iterable=node["children"]):
+            lines.append(
+                self.generate_tree_repr(
+                    node=child,
+                    prefix=prefix + new_prefix,
+                    is_last=i == len(node["children"]) - 1,
+                    is_root=False,
+                    depth=depth,
+                )
+            )
+        return "\n".join(filter(bool, lines))
+
+    def _module_to_tree(self, module: Module) -> TreeNode:
+        match (module._tag, module._show_only_tag()):  # pyright: ignore[reportPrivateUsage]
+            case ("", False):
+                value = str(object=module)
+            case (_, True):
+                value = f"({module._tag})"  # pyright: ignore[reportPrivateUsage]
+            case (_, False):
+                value = f"({module._tag}) {module}"  # pyright: ignore[reportPrivateUsage]
+
+        node: TreeNode = {"value": value, "children": []}
+        for child in module.children():
+            node["children"].append(self._module_to_tree(module=child))  # type: ignore
+        return node
+
+    def _fold_successive_identical(self, node: TreeNode) -> None:
+        i = 0
+        while i < len(node["children"]):
+            j = i
+            while j < len(node["children"]) and node["children"][i] == node["children"][j]:
+                j += 1
+            count = j - i
+            if count > 1:
+                node["children"][i]["value"] += f" (x{count})"
+                del node["children"][i + 1 : j]
+            self._fold_successive_identical(node=node["children"][i])
+            i += 1
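
Note: putting it together, a sketch of how the new tree representation is meant to read. The layers used here (Linear, Residual, Lambda) and the exact output are assumptions based on the code above, not part of the diff:

    import refiners.fluxion.layers as fl

    net = fl.Chain(
        fl.Linear(in_features=4, out_features=4),
        fl.Linear(in_features=4, out_features=4),
        fl.Residual(
            fl.Lambda(lambda x: x * 2),
        ),
    )

    print(repr(net))
    # expected output (assumed):
    # (CHAIN)
    #     ├── Linear(in_features=4, out_features=4) (x2)   <- folded by _fold_successive_identical
    #     └── (SUM) Residual()
    #         └── Lambda(<lambda>(x))
    #
    # net.pretty_print(depth=2) stops after the first level of children,
    # while repr() always caps the depth at 7.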