mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
Add better tree representation for fluxion Module
This commit is contained in:
parent
d9a461e9b5
commit
cf43cb191f
|
@ -19,10 +19,6 @@ class View(Module):
|
|||
def forward(self, x: Tensor) -> Tensor:
    """Return `x.view(...)` called with the dimensions held in `self.shape`."""
    target_shape = self.shape
    return x.view(*target_shape)
|
||||
|
||||
def __repr__(self):
|
||||
shape_repr = ", ".join([repr(s) for s in self.shape])
|
||||
return f"{self.__class__.__name__}({shape_repr})"
|
||||
|
||||
|
||||
class Flatten(Module):
|
||||
def __init__(self, start_dim: int = 0, end_dim: int = -1) -> None:
|
||||
|
@ -33,9 +29,6 @@ class Flatten(Module):
|
|||
def forward(self, x: Tensor) -> Tensor:
    """Return `x.flatten(...)` over the range `self.start_dim` .. `self.end_dim`."""
    first, last = self.start_dim, self.end_dim
    return x.flatten(first, last)
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(start_dim={repr(self.start_dim)}, end_dim={repr(self.end_dim)})"
|
||||
|
||||
|
||||
class Unflatten(Module):
|
||||
def __init__(self, dim: int) -> None:
|
||||
|
@ -45,9 +38,6 @@ class Unflatten(Module):
|
|||
def forward(self, x: Tensor, sizes: Size) -> Tensor:
    """Return `x.unflatten(...)` splitting dimension `self.dim` into `sizes`."""
    target_dim = self.dim
    return x.unflatten(target_dim, sizes)  # type: ignore
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(dim={repr(self.dim)})"
|
||||
|
||||
|
||||
class Reshape(Module):
|
||||
"""
|
||||
|
@ -62,10 +52,6 @@ class Reshape(Module):
|
|||
def forward(self, x: Tensor) -> Tensor:
    """Reshape `x` to `self.shape` while keeping the size of its first dimension."""
    leading_size = x.shape[0]
    return x.reshape(leading_size, *self.shape)
|
||||
|
||||
def __repr__(self):
|
||||
shape_repr = ", ".join([repr(s) for s in self.shape])
|
||||
return f"{self.__class__.__name__}({shape_repr})"
|
||||
|
||||
|
||||
class Transpose(Module):
|
||||
def __init__(self, dim0: int, dim1: int) -> None:
|
||||
|
@ -76,9 +62,6 @@ class Transpose(Module):
|
|||
def forward(self, x: Tensor) -> Tensor:
    """Return `x.transpose(...)` swapping dimensions `self.dim0` and `self.dim1`."""
    first, second = self.dim0, self.dim1
    return x.transpose(first, second)
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(dim0={repr(self.dim0)}, dim1={repr(self.dim1)})"
|
||||
|
||||
|
||||
class Permute(Module):
|
||||
def __init__(self, *dims: int) -> None:
|
||||
|
@ -88,10 +71,6 @@ class Permute(Module):
|
|||
def forward(self, x: Tensor) -> Tensor:
    """Return `x.permute(...)` called with the dimension order held in `self.dims`."""
    order = self.dims
    return x.permute(*order)
|
||||
|
||||
def __repr__(self):
|
||||
dims_repr = ", ".join([repr(d) for d in self.dims])
|
||||
return f"{self.__class__.__name__}({dims_repr})"
|
||||
|
||||
|
||||
class Slicing(Module):
|
||||
def __init__(self, dim: int, start: int, length: int) -> None:
|
||||
|
@ -103,9 +82,6 @@ class Slicing(Module):
|
|||
def forward(self, x: Tensor) -> Tensor:
    """Return `x.narrow(...)` along `self.dim`, from `self.start`, spanning `self.length` entries."""
    dim, start, length = self.dim, self.start, self.length
    return x.narrow(dim, start, length)
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(dim={repr(self.dim)}, start={repr(self.start)}, length={repr(self.length)})"
|
||||
|
||||
|
||||
class Squeeze(Module):
|
||||
def __init__(self, dim: int) -> None:
|
||||
|
@ -115,9 +91,6 @@ class Squeeze(Module):
|
|||
def forward(self, x: Tensor) -> Tensor:
    """Return `x.squeeze(...)` applied to dimension `self.dim`."""
    target_dim = self.dim
    return x.squeeze(target_dim)
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(dim={repr(self.dim)})"
|
||||
|
||||
|
||||
class Unsqueeze(Module):
|
||||
def __init__(self, dim: int) -> None:
|
||||
|
@ -127,9 +100,6 @@ class Unsqueeze(Module):
|
|||
def forward(self, x: Tensor) -> Tensor:
    """Return `x.unsqueeze(...)` applied at dimension `self.dim`."""
    target_dim = self.dim
    return x.unsqueeze(target_dim)
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(dim={repr(self.dim)})"
|
||||
|
||||
|
||||
class Parameter(WeightedModule):
|
||||
"""
|
||||
|
@ -138,6 +108,7 @@ class Parameter(WeightedModule):
|
|||
|
||||
def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
    """Store `dims` and register a parameter named "parameter" built from `randn(*dims, ...)`.

    `device` and `dtype` are forwarded unchanged to `randn`.
    """
    super().__init__()
    # Kept as a plain attribute so the requested dimensions stay available on the instance.
    self.dims = dims
    initial_values = randn(*dims, device=device, dtype=dtype)
    self.register_parameter("parameter", TorchParameter(initial_values))
|
||||
|
||||
@property
|
||||
|
@ -151,10 +122,6 @@ class Parameter(WeightedModule):
|
|||
def forward(self, _: Tensor) -> Tensor:
    """Ignore the input and return `self.parameter`."""
    stored = self.parameter
    return stored
|
||||
|
||||
def __repr__(self):
|
||||
dims_repr = ", ".join([repr(d) for d in list(self.parameter.shape)])
|
||||
return f"{self.__class__.__name__}({dims_repr}, device={repr(self.device)})"
|
||||
|
||||
|
||||
class Buffer(WeightedModule):
|
||||
"""
|
||||
|
@ -165,6 +132,7 @@ class Buffer(WeightedModule):
|
|||
|
||||
def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
    """Store `dims` and register a buffer named "buffer" built from `randn(*dims, ...)`.

    `device` and `dtype` are forwarded unchanged to `randn`.
    """
    super().__init__()
    # Kept as a plain attribute so the requested dimensions stay available on the instance.
    self.dims = dims
    initial_values = randn(*dims, device=device, dtype=dtype)
    self.register_buffer("buffer", initial_values)
|
||||
|
||||
@property
|
||||
|
@ -177,7 +145,3 @@ class Buffer(WeightedModule):
|
|||
|
||||
def forward(self, _: Tensor) -> Tensor:
    """Ignore the input and return `self.buffer`."""
    stored = self.buffer
    return stored
|
||||
|
||||
def __repr__(self):
|
||||
dims_repr = ", ".join([repr(d) for d in list(self.buffer.shape)])
|
||||
return f"{self.__class__.__name__}({dims_repr}, device={repr(self.device)})"
|
||||
|
|
|
@ -20,7 +20,7 @@ class Lambda(Module):
|
|||
def forward(self, *args: Any) -> Any:
    """Call `self.func` with the positional arguments and return its result."""
    wrapped = self.func
    return wrapped(*args)
|
||||
|
||||
def __repr__(self):
|
||||
def __str__(self) -> str:
|
||||
func_name = getattr(self.func, "__name__", "partial_function")
|
||||
return f"Lambda({func_name}{str(inspect.signature(self.func))})"
|
||||
|
||||
|
@ -115,6 +115,7 @@ def structural_copy(m: T) -> T:
|
|||
class Chain(ContextModule):
|
||||
_modules: dict[str, Module]
|
||||
_provider: ContextProvider
|
||||
_tag = "CHAIN"
|
||||
|
||||
def __init__(self, *args: Module | Iterable[Module]) -> None:
|
||||
super().__init__()
|
||||
|
@ -235,28 +236,6 @@ class Chain(ContextModule):
|
|||
def __iter__(self) -> Iterator[Module]:
    """Iterate over the values of `self._modules`, in the mapping's own order."""
    submodules = self._modules.values()
    return iter(submodules)
|
||||
|
||||
def _pretty_print(self, num_tab: int = 0, layer_name: str | None = None) -> str:
    """Build a multi-line description: a `<layer_name>:` header followed by one row per sub-module.

    `layer_name` defaults to the class name. Nested `Chain` instances are expanded recursively
    while `num_tab` is below 12; beyond that they are abbreviated as `<name>(...)`.
    """
    if layer_name is None:
        layer_name = self.__class__.__name__
    tab = " " * (num_tab + 4)
    # Only the first row carries a marker; a Sum uses a distinct one.
    first_marker = "└+" if isinstance(self, Sum) else "└─"

    rows: list[str] = []
    for index, (name, module) in enumerate(self._modules.items()):
        ident = first_marker if index == 0 else " "
        if not isinstance(module, Chain):
            module_str = module
        elif num_tab < 12:
            module_str = module._pretty_print(len(tab), name)
        else:
            module_str = f"{name}(...)"
        rows.append(f"{tab}{ident} {module_str}")

    return f"{layer_name}:\n" + "\n".join(rows)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self._pretty_print()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"<{self.__class__.__name__} at {hex(id(self))}>"
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._modules)
|
||||
|
||||
|
@ -418,25 +397,45 @@ class Chain(ContextModule):
|
|||
|
||||
return clone
|
||||
|
||||
def _show_only_tag(self) -> bool:
    """Return True only when the instance's class is exactly `Chain`, not a subclass."""
    return self.__class__ is Chain
|
||||
|
||||
|
||||
class Parallel(Chain):
    """Chain whose forward gives the same positional arguments to every sub-module.

    Each sub-module is invoked through `call_layer`; the results are returned as a tuple,
    in the order of `self._modules`.
    """

    _tag = "PAR"

    def forward(self, *args: Any) -> tuple[Tensor, ...]:
        return tuple(self.call_layer(module, name, *args) for name, module in self._modules.items())

    def _show_only_tag(self) -> bool:
        # True only for Parallel itself, not for its subclasses.
        return self.__class__ is Parallel
|
||||
|
||||
|
||||
class Distribute(Chain):
    """Chain whose forward pairs the i-th positional argument with the i-th sub-module.

    Each pair is invoked through `call_layer`; the results are returned as a tuple.
    The number of arguments is asserted to equal the number of sub-modules.
    """

    _tag = "DISTR"

    def forward(self, *args: Any) -> tuple[Tensor, ...]:
        assert len(args) == len(self._modules), "Number of positional arguments must match number of sub-modules."
        named_layers = self._modules.items()
        return tuple(self.call_layer(module, name, arg) for arg, (name, module) in zip(args, named_layers))

    def _show_only_tag(self) -> bool:
        # True only for Distribute itself, not for its subclasses.
        return self.__class__ is Distribute
|
||||
|
||||
|
||||
class Passthrough(Chain):
    """Chain that runs the parent forward, discards its result, and returns the inputs tuple."""

    _tag = "PASS"

    def forward(self, *inputs: Any) -> Any:
        # The parent's return value is intentionally not used.
        super().forward(*inputs)
        return inputs

    def _show_only_tag(self) -> bool:
        # True only for Passthrough itself, not for its subclasses.
        return self.__class__ is Passthrough
|
||||
|
||||
|
||||
class Sum(Chain):
|
||||
_tag = "SUM"
|
||||
|
||||
def forward(self, *inputs: Any) -> Any:
|
||||
output = None
|
||||
for layer in self:
|
||||
|
@ -446,6 +445,9 @@ class Sum(Chain):
|
|||
output = layer_output if output is None else output + layer_output
|
||||
return output
|
||||
|
||||
def _show_only_tag(self) -> bool:
    """Return True only when the instance's class is exactly `Sum`, not a subclass."""
    return self.__class__ is Sum
|
||||
|
||||
|
||||
class Residual(Sum):
|
||||
def __init__(self, *modules: Module) -> None:
|
||||
|
@ -468,6 +470,7 @@ class Breakpoint(ContextModule):
|
|||
|
||||
|
||||
class Concatenate(Chain):
|
||||
_tag = "CAT"
|
||||
structural_attrs = ["dim"]
|
||||
|
||||
def __init__(self, *modules: Module, dim: int = 0) -> None:
|
||||
|
@ -477,3 +480,6 @@ class Concatenate(Chain):
|
|||
def forward(self, *args: Any) -> Tensor:
    """Call every sub-module with `args` and concatenate the non-None outputs along `self.dim`."""
    kept = []
    for module in self:
        output = module(*args)
        if output is not None:
            kept.append(output)
    return cat(kept, dim=self.dim)
|
||||
|
||||
def _show_only_tag(self) -> bool:
    """Return True only when the instance's class is exactly `Concatenate`, not a subclass."""
    return self.__class__ is Concatenate
|
||||
|
|
|
@ -8,11 +8,11 @@ class Conv2d(nn.Conv2d, WeightedModule):
|
|||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int | tuple[int, int],
|
||||
stride: int | tuple[int, int] = 1,
|
||||
padding: int | tuple[int, int] | str = 0,
|
||||
stride: int | tuple[int, int] = (1, 1),
|
||||
padding: int | tuple[int, int] | str = (0, 0),
|
||||
groups: int = 1,
|
||||
use_bias: bool = True,
|
||||
dilation: int | tuple[int, int] = 1,
|
||||
dilation: int | tuple[int, int] = (1, 1),
|
||||
padding_mode: str = "zeros",
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
|
@ -30,6 +30,7 @@ class Conv2d(nn.Conv2d, WeightedModule):
|
|||
device,
|
||||
dtype,
|
||||
)
|
||||
self.use_bias = use_bias
|
||||
|
||||
|
||||
class Conv1d(nn.Conv1d, WeightedModule):
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from inspect import signature, Parameter
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator, TypeVar
|
||||
from typing import Any, Generator, TypeVar, TypedDict, cast
|
||||
|
||||
from torch import device as Device, dtype as DType
|
||||
from torch.nn.modules.module import Module as TorchModule
|
||||
|
@ -7,18 +8,20 @@ from torch.nn.modules.module import Module as TorchModule
|
|||
from refiners.fluxion.utils import load_from_safetensors
|
||||
from refiners.fluxion.context import Context, ContextProvider
|
||||
|
||||
from typing import Callable, TYPE_CHECKING
|
||||
from typing import Callable, TYPE_CHECKING, Sequence
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from refiners.fluxion.layers.chain import Chain
|
||||
|
||||
T = TypeVar("T", bound="Module")
|
||||
TContextModule = TypeVar("TContextModule", bound="ContextModule")
|
||||
BasicType = str | float | int | bool
|
||||
|
||||
|
||||
class Module(TorchModule):
|
||||
_parameters: dict[str, Any]
|
||||
_buffers: dict[str, Any]
|
||||
_tag: str = ""
|
||||
|
||||
__getattr__: Callable[["Module", str], Any] # type: ignore
|
||||
__setattr__: Callable[["Module", str, Any], None] # type: ignore
|
||||
|
@ -37,6 +40,56 @@ class Module(TorchModule):
|
|||
def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T:  # type: ignore
    """Forward `device` and `dtype` as keywords to the parent class's `to` and return its result."""
    moved = super().to(device=device, dtype=dtype)  # type: ignore
    return moved
|
||||
|
||||
def __str__(self) -> str:
|
||||
basic_attributes_str = ", ".join(
|
||||
f"{key}={value}" for key, value in self.basic_attributes(init_attrs_only=True).items()
|
||||
)
|
||||
result = f"{self.__class__.__name__}({basic_attributes_str})"
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
    """Return the repr of a `ModuleTree` built from this module."""
    return repr(ModuleTree(module=self))
|
||||
|
||||
def pretty_print(self, depth: int = -1) -> None:
    """Print the `ModuleTree` rendering of this module, passing `depth` to `generate_tree_repr`."""
    tree = ModuleTree(module=self)
    rendered = tree.generate_tree_repr(tree.root, is_root=True, depth=depth)
    print(rendered)
|
||||
|
||||
def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType]:
    """Collect the module's public, basic-typed instance attributes.

    An entry of `self.__dict__` qualifies when its name has no leading underscore and its value is a
    `BasicType` or a sequence whose items are all `BasicType`. Every value is returned through `str()`.

    With `init_attrs_only=True`, only attributes named after an `__init__` parameter are kept, and those
    equal to that parameter's default value are dropped.
    """
    init_signature = signature(obj=self.__init__)
    init_param_names = set(init_signature.parameters) - {"self"}
    defaults = {
        name: param.default
        for name, param in init_signature.parameters.items()
        if param.default is not Parameter.empty
    }

    def is_basic_attribute(key: str, value: Any) -> bool:
        if key.startswith("_"):
            return False
        if isinstance(value, BasicType):
            return True
        return isinstance(value, Sequence) and all(
            isinstance(item, BasicType) for item in cast(Sequence[Any], value)
        )

    def is_selected(key: str, value: Any) -> bool:
        if not init_attrs_only:
            return True
        # A parameter without a default compares against None, so its attribute is always kept.
        return key in init_param_names and value != defaults.get(key)

    attributes: dict[str, BasicType] = {}
    for key, value in self.__dict__.items():
        if is_basic_attribute(key=key, value=value) and is_selected(key=key, value=value):
            attributes[key] = str(object=value)
    return attributes
|
||||
|
||||
def _show_only_tag(self) -> bool:
|
||||
"""Whether to show only the tag when printing the module.
|
||||
|
||||
This is useful to distinguish between Chain subclasses that override their forward from one another.
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
class ContextModule(Module):
|
||||
# we store parent into a one element list to avoid pytorch thinking it's a submodule
|
||||
|
@ -100,3 +153,73 @@ class WeightedModule(Module):
|
|||
@property
def dtype(self) -> DType:
    """Return the `dtype` of `self.weight`."""
    weight = self.weight
    return weight.dtype
|
||||
|
||||
|
||||
class TreeNode(TypedDict):
    """One node of the printable tree: a text label and its ordered child nodes."""

    value: str
    children: list["TreeNode"]


class ModuleTree:
    """Printable tree built from a module and its children.

    The tree is built once at construction time; runs of successive identical siblings are then
    folded into a single node whose label gets a ` (xN)` suffix.
    """

    def __init__(self, module: "Module") -> None:
        self.root: TreeNode = self._module_to_tree(module=module)
        self._fold_successive_identical(node=self.root)

    def __str__(self) -> str:
        return f"{self.__class__.__name__}(root={self.root['value']})"

    def __repr__(self) -> str:
        # The default rendering is limited to 7 levels, root included.
        return self.generate_tree_repr(node=self.root, is_root=True, depth=7)

    def generate_tree_repr(
        self, node: TreeNode, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1
    ) -> str:
        """Render `node` and its descendants as text, one node per line.

        Args:
            node: The node to render.
            prefix: Text placed before the branch glyph on this node's line.
            is_last: Whether the node is the last child of its parent.
            is_root: Whether the node is the root; the root line has no branch glyph.
            depth: Number of levels to render, this node included. A negative value means no limit;
                0 renders nothing.

        Returns:
            The rendered lines joined with newlines; empty when `depth` is 0.
        """
        if depth == 0:
            return ""
        remaining_depth = depth - 1 if depth > 0 else depth

        tree_icon = "" if is_root else ("└── " if is_last else "├── ")
        # Continuation prefixes are as wide as the branch glyphs so nested levels line up.
        child_prefix = prefix + ("    " if is_last else "│   ")

        children = node["children"]
        last_index = len(children) - 1
        lines = [f"{prefix}{tree_icon}{node['value']}"]
        for index, child in enumerate(children):
            lines.append(
                self.generate_tree_repr(
                    node=child,
                    prefix=child_prefix,
                    is_last=index == last_index,
                    is_root=False,
                    depth=remaining_depth,
                )
            )

        # Children cut off by the depth limit render as empty strings; drop them.
        return "\n".join(line for line in lines if line)

    def _module_to_tree(self, module: "Module") -> TreeNode:
        """Convert `module` and, recursively, everything yielded by `module.children()` into nodes.

        The label is `(<tag>)` when the module asks to show only its tag, `str(module)` when it
        has no tag, and `(<tag>) <str(module)>` otherwise.
        """
        tag = module._tag  # pyright: ignore[reportPrivateUsage]
        if module._show_only_tag():  # pyright: ignore[reportPrivateUsage]
            value = f"({tag})"
        elif tag == "":
            value = str(module)
        else:
            value = f"({tag}) {module}"

        return {
            "value": value,
            "children": [self._module_to_tree(module=child) for child in module.children()],  # type: ignore
        }

    def _fold_successive_identical(self, node: TreeNode) -> None:
        """Collapse, in place, each run of successive equal children into its first node.

        The kept node's label gets a ` (xN)` suffix, where N is the length of the run. The same
        folding is then applied to the children of every remaining node.
        """
        children = node["children"]
        i = 0
        while i < len(children):
            # Find the end of the run of siblings equal to children[i] (labels and subtrees).
            j = i + 1
            while j < len(children) and children[i] == children[j]:
                j += 1
            count = j - i
            if count > 1:
                children[i]["value"] += f" (x{count})"
                del children[i + 1 : j]
            self._fold_successive_identical(node=children[i])
            i += 1
|
||||
|
|
Loading…
Reference in a new issue