Add better tree representation for fluxion Module

Benjamin Trom 2023-09-07 16:15:02 +02:00
parent d9a461e9b5
commit cf43cb191f
4 changed files with 160 additions and 66 deletions
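In short: the hand-written __repr__ overrides scattered across the fluxion layers are deleted, and Module itself gains a generic representation: __str__ builds a flat Name(attr=value) line from basic_attributes, while __repr__ renders the module hierarchy through a new ModuleTree helper, with Chain subclasses contributing short _tag labels (CHAIN, PAR, SUM, ...). A rough sketch of the intended output after this change (layer names and exact spacing are illustrative, not part of the diff):

    from refiners.fluxion.layers import Chain, Linear  # imports assumed

    chain = Chain(Linear(in_features=8, out_features=8), Linear(in_features=8, out_features=8))
    print(repr(chain))
    # (CHAIN)
    #     └── Linear(in_features=8, out_features=8) (x2)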

View file

@@ -19,10 +19,6 @@ class View(Module):
     def forward(self, x: Tensor) -> Tensor:
         return x.view(*self.shape)
 
-    def __repr__(self):
-        shape_repr = ", ".join([repr(s) for s in self.shape])
-        return f"{self.__class__.__name__}({shape_repr})"
-
 
 class Flatten(Module):
     def __init__(self, start_dim: int = 0, end_dim: int = -1) -> None:
@@ -33,9 +29,6 @@ class Flatten(Module):
     def forward(self, x: Tensor) -> Tensor:
         return x.flatten(self.start_dim, self.end_dim)
 
-    def __repr__(self):
-        return f"{self.__class__.__name__}(start_dim={repr(self.start_dim)}, end_dim={repr(self.end_dim)})"
-
 
 class Unflatten(Module):
     def __init__(self, dim: int) -> None:
@@ -45,9 +38,6 @@ class Unflatten(Module):
     def forward(self, x: Tensor, sizes: Size) -> Tensor:
         return x.unflatten(self.dim, sizes)  # type: ignore
 
-    def __repr__(self):
-        return f"{self.__class__.__name__}(dim={repr(self.dim)})"
-
 
 class Reshape(Module):
     """
@@ -62,10 +52,6 @@ class Reshape(Module):
     def forward(self, x: Tensor) -> Tensor:
         return x.reshape(x.shape[0], *self.shape)
 
-    def __repr__(self):
-        shape_repr = ", ".join([repr(s) for s in self.shape])
-        return f"{self.__class__.__name__}({shape_repr})"
-
 
 class Transpose(Module):
     def __init__(self, dim0: int, dim1: int) -> None:
@@ -76,9 +62,6 @@ class Transpose(Module):
     def forward(self, x: Tensor) -> Tensor:
         return x.transpose(self.dim0, self.dim1)
 
-    def __repr__(self):
-        return f"{self.__class__.__name__}(dim0={repr(self.dim0)}, dim1={repr(self.dim1)})"
-
 
 class Permute(Module):
     def __init__(self, *dims: int) -> None:
@@ -88,10 +71,6 @@ class Permute(Module):
     def forward(self, x: Tensor) -> Tensor:
         return x.permute(*self.dims)
 
-    def __repr__(self):
-        dims_repr = ", ".join([repr(d) for d in self.dims])
-        return f"{self.__class__.__name__}({dims_repr})"
-
 
 class Slicing(Module):
     def __init__(self, dim: int, start: int, length: int) -> None:
@@ -103,9 +82,6 @@ class Slicing(Module):
     def forward(self, x: Tensor) -> Tensor:
         return x.narrow(self.dim, self.start, self.length)
 
-    def __repr__(self):
-        return f"{self.__class__.__name__}(dim={repr(self.dim)}, start={repr(self.start)}, length={repr(self.length)})"
-
 
 class Squeeze(Module):
     def __init__(self, dim: int) -> None:
@@ -115,9 +91,6 @@ class Squeeze(Module):
     def forward(self, x: Tensor) -> Tensor:
         return x.squeeze(self.dim)
 
-    def __repr__(self):
-        return f"{self.__class__.__name__}(dim={repr(self.dim)})"
-
 
 class Unsqueeze(Module):
     def __init__(self, dim: int) -> None:
@@ -127,9 +100,6 @@ class Unsqueeze(Module):
     def forward(self, x: Tensor) -> Tensor:
         return x.unsqueeze(self.dim)
 
-    def __repr__(self):
-        return f"{self.__class__.__name__}(dim={repr(self.dim)})"
-
 
 class Parameter(WeightedModule):
     """
@@ -138,6 +108,7 @@ class Parameter(WeightedModule):
 
     def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
         super().__init__()
+        self.dims = dims
         self.register_parameter("parameter", TorchParameter(randn(*dims, device=device, dtype=dtype)))
 
     @property
@@ -151,10 +122,6 @@ class Parameter(WeightedModule):
     def forward(self, _: Tensor) -> Tensor:
         return self.parameter
 
-    def __repr__(self):
-        dims_repr = ", ".join([repr(d) for d in list(self.parameter.shape)])
-        return f"{self.__class__.__name__}({dims_repr}, device={repr(self.device)})"
-
 
 class Buffer(WeightedModule):
     """
@@ -165,6 +132,7 @@ class Buffer(WeightedModule):
 
     def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
         super().__init__()
+        self.dims = dims
         self.register_buffer("buffer", randn(*dims, device=device, dtype=dtype))
 
     @property
@@ -177,7 +145,3 @@ class Buffer(WeightedModule):
 
     def forward(self, _: Tensor) -> Tensor:
         return self.buffer
-
-    def __repr__(self):
-        dims_repr = ", ".join([repr(d) for d in list(self.buffer.shape)])
-        return f"{self.__class__.__name__}({dims_repr}, device={repr(self.device)})"
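These deletions are safe because the generic Module.__str__ added in module.py (last file below) reconstructs an equivalent line from basic_attributes; the new self.dims attribute on Parameter and Buffer exists so their shape can surface there, since the underlying tensors themselves are not basic attributes. A sketch of the expected behavior, assuming these layers keep their current constructors:

    str(Flatten(start_dim=1))  # "Flatten(start_dim=1)" -- end_dim is omitted, it equals its default of -1
    str(Parameter(3, 4))       # "Parameter(dims=(3, 4))" -- visible thanks to the new self.dims attribute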

View file

@@ -20,7 +20,7 @@ class Lambda(Module):
 
     def forward(self, *args: Any) -> Any:
         return self.func(*args)
 
-    def __repr__(self):
+    def __str__(self) -> str:
         func_name = getattr(self.func, "__name__", "partial_function")
         return f"Lambda({func_name}{str(inspect.signature(self.func))})"
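Lambda keeps a custom textual form but moves it from __repr__ to __str__, so the new tree machinery, which calls str() on untagged leaf modules, picks it up. A sketch with a plain wrapped function (hypothetical):

    def double(x: Tensor) -> Tensor:
        return x * 2

    str(Lambda(double))  # "Lambda(double(x: Tensor) -> Tensor)"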
@@ -115,6 +115,7 @@ def structural_copy(m: T) -> T:
 class Chain(ContextModule):
     _modules: dict[str, Module]
     _provider: ContextProvider
+    _tag = "CHAIN"
 
     def __init__(self, *args: Module | Iterable[Module]) -> None:
         super().__init__()
@@ -235,28 +236,6 @@ class Chain(ContextModule):
 
     def __iter__(self) -> Iterator[Module]:
         return iter(self._modules.values())
-
-    def _pretty_print(self, num_tab: int = 0, layer_name: str | None = None) -> str:
-        layer_name = self.__class__.__name__ if layer_name is None else layer_name
-        pretty_print = f"{layer_name}:\n"
-        tab = " " * (num_tab + 4)
-        module_strings: list[str] = []
-        for i, (name, module) in enumerate(self._modules.items()):
-            ident = ("└+" if isinstance(self, Sum) else "└─") if i == 0 else " "
-            module_str = (
-                module
-                if not isinstance(module, Chain)
-                else (module._pretty_print(len(tab), name) if num_tab < 12 else f"{name}(...)")
-            )
-            module_strings.append(f"{tab}{ident} {module_str}")
-        pretty_print += "\n".join(module_strings)
-        return pretty_print
-
-    def __repr__(self) -> str:
-        return self._pretty_print()
-
-    def __str__(self) -> str:
-        return f"<{self.__class__.__name__} at {hex(id(self))}>"
 
     def __len__(self) -> int:
         return len(self._modules)
@@ -418,25 +397,45 @@ class Chain(ContextModule):
 
         return clone
 
+    def _show_only_tag(self) -> bool:
+        return self.__class__ == Chain
+
 
 class Parallel(Chain):
+    _tag = "PAR"
+
     def forward(self, *args: Any) -> tuple[Tensor, ...]:
         return tuple([self.call_layer(module, name, *args) for name, module in self._modules.items()])
 
+    def _show_only_tag(self) -> bool:
+        return self.__class__ == Parallel
+
 
 class Distribute(Chain):
+    _tag = "DISTR"
+
     def forward(self, *args: Any) -> tuple[Tensor, ...]:
         assert len(args) == len(self._modules), "Number of positional arguments must match number of sub-modules."
         return tuple([self.call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())])
 
+    def _show_only_tag(self) -> bool:
+        return self.__class__ == Distribute
+
 
 class Passthrough(Chain):
+    _tag = "PASS"
+
     def forward(self, *inputs: Any) -> Any:
         super().forward(*inputs)
         return inputs
 
+    def _show_only_tag(self) -> bool:
+        return self.__class__ == Passthrough
+
 
 class Sum(Chain):
+    _tag = "SUM"
+
     def forward(self, *inputs: Any) -> Any:
         output = None
         for layer in self:
@@ -446,6 +445,9 @@ class Sum(Chain):
             output = layer_output if output is None else output + layer_output
         return output
 
+    def _show_only_tag(self) -> bool:
+        return self.__class__ == Sum
+
 
 class Residual(Sum):
     def __init__(self, *modules: Module) -> None:
@@ -468,6 +470,7 @@ class Breakpoint(ContextModule):
 
 
 class Concatenate(Chain):
+    _tag = "CAT"
     structural_attrs = ["dim"]
 
     def __init__(self, *modules: Module, dim: int = 0) -> None:
@@ -477,3 +480,6 @@ class Concatenate(Chain):
     def forward(self, *args: Any) -> Tensor:
         outputs = [module(*args) for module in self]
         return cat([output for output in outputs if output is not None], dim=self.dim)
+
+    def _show_only_tag(self) -> bool:
+        return self.__class__ == Concatenate
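Each _show_only_tag override returns True only for the exact class, never for subclasses. Combined with _module_to_tree in module.py below, a plain container collapses to its tag in the tree, while a subclass that customizes forward keeps its class name next to the tag. A hypothetical subclass to illustrate:

    class ScaledSum(Sum):
        def forward(self, *inputs: Any) -> Any:
            return 0.5 * super().forward(*inputs)

    # in a tree, Sum(...) renders as "(SUM)" but ScaledSum(...) as "(SUM) ScaledSum()",
    # so the two remain distinguishable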

View file

@@ -8,11 +8,11 @@ class Conv2d(nn.Conv2d, WeightedModule):
         in_channels: int,
         out_channels: int,
         kernel_size: int | tuple[int, int],
-        stride: int | tuple[int, int] = 1,
-        padding: int | tuple[int, int] | str = 0,
+        stride: int | tuple[int, int] = (1, 1),
+        padding: int | tuple[int, int] | str = (0, 0),
         groups: int = 1,
         use_bias: bool = True,
-        dilation: int | tuple[int, int] = 1,
+        dilation: int | tuple[int, int] = (1, 1),
         padding_mode: str = "zeros",
         device: Device | str | None = None,
         dtype: DType | None = None,
@@ -30,6 +30,7 @@ class Conv2d(nn.Conv2d, WeightedModule):
             device,
             dtype,
         )
+        self.use_bias = use_bias
 
 
 class Conv1d(nn.Conv1d, WeightedModule):
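The switch from int to tuple defaults looks cosmetic but matters for basic_attributes(init_attrs_only=True) below: torch's nn.Conv2d normalizes stride, padding and dilation to tuples internally, so a stored value of (1, 1) would never compare equal to a declared default of 1 and those attributes would always be printed. Storing use_bias likewise lets it surface as a basic attribute. A sketch of the resulting one-liner, assuming the generic __str__:

    conv = Conv2d(in_channels=3, out_channels=8, kernel_size=3)
    str(conv)  # "Conv2d(in_channels=3, out_channels=8, kernel_size=(3, 3))"
    # stride, padding, dilation, groups and use_bias are omitted: they equal their defaults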

View file

@@ -1,5 +1,6 @@
+from inspect import signature, Parameter
 from pathlib import Path
-from typing import Any, Generator, TypeVar
+from typing import Any, Generator, TypeVar, TypedDict, cast
 from torch import device as Device, dtype as DType
 from torch.nn.modules.module import Module as TorchModule
 
@@ -7,18 +8,20 @@ from torch.nn.modules.module import Module as TorchModule
 from refiners.fluxion.utils import load_from_safetensors
 from refiners.fluxion.context import Context, ContextProvider
-from typing import Callable, TYPE_CHECKING
+from typing import Callable, TYPE_CHECKING, Sequence
 
 if TYPE_CHECKING:
     from refiners.fluxion.layers.chain import Chain
 
 T = TypeVar("T", bound="Module")
 TContextModule = TypeVar("TContextModule", bound="ContextModule")
+BasicType = str | float | int | bool
 
 
 class Module(TorchModule):
     _parameters: dict[str, Any]
     _buffers: dict[str, Any]
+    _tag: str = ""
 
     __getattr__: Callable[["Module", str], Any]  # type: ignore
     __setattr__: Callable[["Module", str, Any], None]  # type: ignore
 
@@ -37,6 +40,56 @@ class Module(TorchModule):
     def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T:  # type: ignore
         return super().to(device=device, dtype=dtype)  # type: ignore
 
+    def __str__(self) -> str:
+        basic_attributes_str = ", ".join(
+            f"{key}={value}" for key, value in self.basic_attributes(init_attrs_only=True).items()
+        )
+        result = f"{self.__class__.__name__}({basic_attributes_str})"
+        return result
+
+    def __repr__(self) -> str:
+        tree = ModuleTree(module=self)
+        return repr(tree)
+
+    def pretty_print(self, depth: int = -1) -> None:
+        tree = ModuleTree(module=self)
+        print(tree.generate_tree_repr(tree.root, is_root=True, depth=depth))
+
+    def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType]:
+        """Return a dictionary of basic attributes of the module.
+
+        Basic attributes are public attributes made of basic types (int, float, str, bool) or a sequence of basic types.
+        """
+        sig = signature(obj=self.__init__)
+        init_params = set(sig.parameters.keys()) - {"self"}
+        default_values = {k: v.default for k, v in sig.parameters.items() if v.default is not Parameter.empty}
+
+        def is_basic_attribute(key: str, value: Any) -> bool:
+            if key.startswith("_"):
+                return False
+            if isinstance(value, BasicType):
+                return True
+            if isinstance(value, Sequence) and all(isinstance(y, BasicType) for y in cast(Sequence[Any], value)):
+                return True
+            return False
+
+        return {
+            key: str(object=value)
+            for key, value in self.__dict__.items()
+            if is_basic_attribute(key=key, value=value)
+            and (not init_attrs_only or (key in init_params and value != default_values.get(key)))
+        }
+
+    def _show_only_tag(self) -> bool:
+        """Whether to show only the tag when printing the module.
+
+        This is useful to distinguish between Chain subclasses that override their forward from one another.
+        """
+        return False
 
 
 class ContextModule(Module):
     # we store parent into a one element list to avoid pytorch thinking it's a submodule
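A module thus gets three complementary views: str() for a flat constructor-like line of non-default init attributes, repr() for a tree capped at depth 7, and pretty_print() for a tree of chosen depth. Note that is_basic_attribute skips private attributes and anything that is not a basic type or a sequence of basic types. A small hypothetical module to illustrate the filtering:

    class Example(Module):
        def __init__(self, size: int = 4) -> None:
            super().__init__()
            self.size = size         # basic type, and an init parameter
            self.scale = 2.0         # basic type, but not an init parameter
            self._cache = [1, 2, 3]  # private: always skipped

    Example(size=8).basic_attributes(init_attrs_only=True)  # {"size": "8"}
    str(Example(size=8))                                    # "Example(size=8)"
    # without init_attrs_only, "scale" (and torch's own public flags such as
    # "training") would pass the filter too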
@@ -100,3 +153,73 @@ class WeightedModule(Module):
     @property
     def dtype(self) -> DType:
         return self.weight.dtype
+
+
+class TreeNode(TypedDict):
+    value: str
+    children: list["TreeNode"]
+
+
+class ModuleTree:
+    def __init__(self, module: Module) -> None:
+        self.root: TreeNode = self._module_to_tree(module=module)
+        self._fold_successive_identical(node=self.root)
+
+    def __str__(self) -> str:
+        return f"{self.__class__.__name__}(root={self.root['value']})"
+
+    def __repr__(self) -> str:
+        return self.generate_tree_repr(node=self.root, is_root=True, depth=7)
+
+    def generate_tree_repr(
+        self, node: TreeNode, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1
+    ) -> str:
+        if depth == 0:
+            return ""
+        if depth > 0:
+            depth -= 1
+
+        tree_icon: str = "" if is_root else ("└── " if is_last else "├── ")
+        lines = [f"{prefix}{tree_icon}{node['value']}"]
+        new_prefix: str = "    " if is_last else "│   "
+        for i, child in enumerate(iterable=node["children"]):
+            lines.append(
+                self.generate_tree_repr(
+                    node=child,
+                    prefix=prefix + new_prefix,
+                    is_last=i == len(node["children"]) - 1,
+                    is_root=False,
+                    depth=depth,
+                )
+            )
+
+        return "\n".join(filter(bool, lines))
+
+    def _module_to_tree(self, module: Module) -> TreeNode:
+        match (module._tag, module._show_only_tag()):  # pyright: ignore[reportPrivateUsage]
+            case ("", False):
+                value = str(object=module)
+            case (_, True):
+                value = f"({module._tag})"  # pyright: ignore[reportPrivateUsage]
+            case (_, False):
+                value = f"({module._tag}) {module}"  # pyright: ignore[reportPrivateUsage]
+
+        node: TreeNode = {"value": value, "children": []}
+        for child in module.children():
+            node["children"].append(self._module_to_tree(module=child))  # type: ignore
+
+        return node
+
+    def _fold_successive_identical(self, node: TreeNode) -> None:
+        i = 0
+        while i < len(node["children"]):
+            j = i
+            while j < len(node["children"]) and node["children"][i] == node["children"][j]:
+                j += 1
+
+            count = j - i
+            if count > 1:
+                node["children"][i]["value"] += f" (x{count})"
+                del node["children"][i + 1 : j]
+            self._fold_successive_identical(node=node["children"][i])
+            i += 1
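Folding only merges runs of consecutive identical subtrees (plain dict equality on value and children), which is what keeps stacks of repeated blocks readable without hiding structure. A final sketch, with illustrative layers:

    tree = ModuleTree(module=Chain(Linear(8, 8), Linear(8, 8), Linear(8, 4)))
    print(repr(tree))
    # (CHAIN)
    #     ├── Linear(in_features=8, out_features=8) (x2)
    #     └── Linear(in_features=8, out_features=4)
    # an identical layer appearing later, after a different sibling, would not
    # be merged into the earlier run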