mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
Add better tree representation for fluxion Module
This commit is contained in:
parent
d9a461e9b5
commit
cf43cb191f
|
@ -19,10 +19,6 @@ class View(Module):
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return x.view(*self.shape)
|
return x.view(*self.shape)
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
shape_repr = ", ".join([repr(s) for s in self.shape])
|
|
||||||
return f"{self.__class__.__name__}({shape_repr})"
|
|
||||||
|
|
||||||
|
|
||||||
class Flatten(Module):
|
class Flatten(Module):
|
||||||
def __init__(self, start_dim: int = 0, end_dim: int = -1) -> None:
|
def __init__(self, start_dim: int = 0, end_dim: int = -1) -> None:
|
||||||
|
@ -33,9 +29,6 @@ class Flatten(Module):
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return x.flatten(self.start_dim, self.end_dim)
|
return x.flatten(self.start_dim, self.end_dim)
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"{self.__class__.__name__}(start_dim={repr(self.start_dim)}, end_dim={repr(self.end_dim)})"
|
|
||||||
|
|
||||||
|
|
||||||
class Unflatten(Module):
|
class Unflatten(Module):
|
||||||
def __init__(self, dim: int) -> None:
|
def __init__(self, dim: int) -> None:
|
||||||
|
@ -45,9 +38,6 @@ class Unflatten(Module):
|
||||||
def forward(self, x: Tensor, sizes: Size) -> Tensor:
|
def forward(self, x: Tensor, sizes: Size) -> Tensor:
|
||||||
return x.unflatten(self.dim, sizes) # type: ignore
|
return x.unflatten(self.dim, sizes) # type: ignore
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"{self.__class__.__name__}(dim={repr(self.dim)})"
|
|
||||||
|
|
||||||
|
|
||||||
class Reshape(Module):
|
class Reshape(Module):
|
||||||
"""
|
"""
|
||||||
|
@ -62,10 +52,6 @@ class Reshape(Module):
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return x.reshape(x.shape[0], *self.shape)
|
return x.reshape(x.shape[0], *self.shape)
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
shape_repr = ", ".join([repr(s) for s in self.shape])
|
|
||||||
return f"{self.__class__.__name__}({shape_repr})"
|
|
||||||
|
|
||||||
|
|
||||||
class Transpose(Module):
|
class Transpose(Module):
|
||||||
def __init__(self, dim0: int, dim1: int) -> None:
|
def __init__(self, dim0: int, dim1: int) -> None:
|
||||||
|
@ -76,9 +62,6 @@ class Transpose(Module):
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return x.transpose(self.dim0, self.dim1)
|
return x.transpose(self.dim0, self.dim1)
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"{self.__class__.__name__}(dim0={repr(self.dim0)}, dim1={repr(self.dim1)})"
|
|
||||||
|
|
||||||
|
|
||||||
class Permute(Module):
|
class Permute(Module):
|
||||||
def __init__(self, *dims: int) -> None:
|
def __init__(self, *dims: int) -> None:
|
||||||
|
@ -88,10 +71,6 @@ class Permute(Module):
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return x.permute(*self.dims)
|
return x.permute(*self.dims)
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
dims_repr = ", ".join([repr(d) for d in self.dims])
|
|
||||||
return f"{self.__class__.__name__}({dims_repr})"
|
|
||||||
|
|
||||||
|
|
||||||
class Slicing(Module):
|
class Slicing(Module):
|
||||||
def __init__(self, dim: int, start: int, length: int) -> None:
|
def __init__(self, dim: int, start: int, length: int) -> None:
|
||||||
|
@ -103,9 +82,6 @@ class Slicing(Module):
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return x.narrow(self.dim, self.start, self.length)
|
return x.narrow(self.dim, self.start, self.length)
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"{self.__class__.__name__}(dim={repr(self.dim)}, start={repr(self.start)}, length={repr(self.length)})"
|
|
||||||
|
|
||||||
|
|
||||||
class Squeeze(Module):
|
class Squeeze(Module):
|
||||||
def __init__(self, dim: int) -> None:
|
def __init__(self, dim: int) -> None:
|
||||||
|
@ -115,9 +91,6 @@ class Squeeze(Module):
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return x.squeeze(self.dim)
|
return x.squeeze(self.dim)
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"{self.__class__.__name__}(dim={repr(self.dim)})"
|
|
||||||
|
|
||||||
|
|
||||||
class Unsqueeze(Module):
|
class Unsqueeze(Module):
|
||||||
def __init__(self, dim: int) -> None:
|
def __init__(self, dim: int) -> None:
|
||||||
|
@ -127,9 +100,6 @@ class Unsqueeze(Module):
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return x.unsqueeze(self.dim)
|
return x.unsqueeze(self.dim)
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"{self.__class__.__name__}(dim={repr(self.dim)})"
|
|
||||||
|
|
||||||
|
|
||||||
class Parameter(WeightedModule):
|
class Parameter(WeightedModule):
|
||||||
"""
|
"""
|
||||||
|
@ -138,6 +108,7 @@ class Parameter(WeightedModule):
|
||||||
|
|
||||||
def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.dims = dims
|
||||||
self.register_parameter("parameter", TorchParameter(randn(*dims, device=device, dtype=dtype)))
|
self.register_parameter("parameter", TorchParameter(randn(*dims, device=device, dtype=dtype)))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -151,10 +122,6 @@ class Parameter(WeightedModule):
|
||||||
def forward(self, _: Tensor) -> Tensor:
|
def forward(self, _: Tensor) -> Tensor:
|
||||||
return self.parameter
|
return self.parameter
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
dims_repr = ", ".join([repr(d) for d in list(self.parameter.shape)])
|
|
||||||
return f"{self.__class__.__name__}({dims_repr}, device={repr(self.device)})"
|
|
||||||
|
|
||||||
|
|
||||||
class Buffer(WeightedModule):
|
class Buffer(WeightedModule):
|
||||||
"""
|
"""
|
||||||
|
@ -165,6 +132,7 @@ class Buffer(WeightedModule):
|
||||||
|
|
||||||
def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.dims = dims
|
||||||
self.register_buffer("buffer", randn(*dims, device=device, dtype=dtype))
|
self.register_buffer("buffer", randn(*dims, device=device, dtype=dtype))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -177,7 +145,3 @@ class Buffer(WeightedModule):
|
||||||
|
|
||||||
def forward(self, _: Tensor) -> Tensor:
|
def forward(self, _: Tensor) -> Tensor:
|
||||||
return self.buffer
|
return self.buffer
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
dims_repr = ", ".join([repr(d) for d in list(self.buffer.shape)])
|
|
||||||
return f"{self.__class__.__name__}({dims_repr}, device={repr(self.device)})"
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ class Lambda(Module):
|
||||||
def forward(self, *args: Any) -> Any:
|
def forward(self, *args: Any) -> Any:
|
||||||
return self.func(*args)
|
return self.func(*args)
|
||||||
|
|
||||||
def __repr__(self):
|
def __str__(self) -> str:
|
||||||
func_name = getattr(self.func, "__name__", "partial_function")
|
func_name = getattr(self.func, "__name__", "partial_function")
|
||||||
return f"Lambda({func_name}{str(inspect.signature(self.func))})"
|
return f"Lambda({func_name}{str(inspect.signature(self.func))})"
|
||||||
|
|
||||||
|
@ -115,6 +115,7 @@ def structural_copy(m: T) -> T:
|
||||||
class Chain(ContextModule):
|
class Chain(ContextModule):
|
||||||
_modules: dict[str, Module]
|
_modules: dict[str, Module]
|
||||||
_provider: ContextProvider
|
_provider: ContextProvider
|
||||||
|
_tag = "CHAIN"
|
||||||
|
|
||||||
def __init__(self, *args: Module | Iterable[Module]) -> None:
|
def __init__(self, *args: Module | Iterable[Module]) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -235,28 +236,6 @@ class Chain(ContextModule):
|
||||||
def __iter__(self) -> Iterator[Module]:
|
def __iter__(self) -> Iterator[Module]:
|
||||||
return iter(self._modules.values())
|
return iter(self._modules.values())
|
||||||
|
|
||||||
def _pretty_print(self, num_tab: int = 0, layer_name: str | None = None) -> str:
|
|
||||||
layer_name = self.__class__.__name__ if layer_name is None else layer_name
|
|
||||||
pretty_print = f"{layer_name}:\n"
|
|
||||||
tab = " " * (num_tab + 4)
|
|
||||||
module_strings: list[str] = []
|
|
||||||
for i, (name, module) in enumerate(self._modules.items()):
|
|
||||||
ident = ("└+" if isinstance(self, Sum) else "└─") if i == 0 else " "
|
|
||||||
module_str = (
|
|
||||||
module
|
|
||||||
if not isinstance(module, Chain)
|
|
||||||
else (module._pretty_print(len(tab), name) if num_tab < 12 else f"{name}(...)")
|
|
||||||
)
|
|
||||||
module_strings.append(f"{tab}{ident} {module_str}")
|
|
||||||
pretty_print += "\n".join(module_strings)
|
|
||||||
return pretty_print
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return self._pretty_print()
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return f"<{self.__class__.__name__} at {hex(id(self))}>"
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self._modules)
|
return len(self._modules)
|
||||||
|
|
||||||
|
@ -418,25 +397,45 @@ class Chain(ContextModule):
|
||||||
|
|
||||||
return clone
|
return clone
|
||||||
|
|
||||||
|
def _show_only_tag(self) -> bool:
|
||||||
|
return self.__class__ == Chain
|
||||||
|
|
||||||
|
|
||||||
class Parallel(Chain):
|
class Parallel(Chain):
|
||||||
|
_tag = "PAR"
|
||||||
|
|
||||||
def forward(self, *args: Any) -> tuple[Tensor, ...]:
|
def forward(self, *args: Any) -> tuple[Tensor, ...]:
|
||||||
return tuple([self.call_layer(module, name, *args) for name, module in self._modules.items()])
|
return tuple([self.call_layer(module, name, *args) for name, module in self._modules.items()])
|
||||||
|
|
||||||
|
def _show_only_tag(self) -> bool:
|
||||||
|
return self.__class__ == Parallel
|
||||||
|
|
||||||
|
|
||||||
class Distribute(Chain):
|
class Distribute(Chain):
|
||||||
|
_tag = "DISTR"
|
||||||
|
|
||||||
def forward(self, *args: Any) -> tuple[Tensor, ...]:
|
def forward(self, *args: Any) -> tuple[Tensor, ...]:
|
||||||
assert len(args) == len(self._modules), "Number of positional arguments must match number of sub-modules."
|
assert len(args) == len(self._modules), "Number of positional arguments must match number of sub-modules."
|
||||||
return tuple([self.call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())])
|
return tuple([self.call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())])
|
||||||
|
|
||||||
|
def _show_only_tag(self) -> bool:
|
||||||
|
return self.__class__ == Distribute
|
||||||
|
|
||||||
|
|
||||||
class Passthrough(Chain):
|
class Passthrough(Chain):
|
||||||
|
_tag = "PASS"
|
||||||
|
|
||||||
def forward(self, *inputs: Any) -> Any:
|
def forward(self, *inputs: Any) -> Any:
|
||||||
super().forward(*inputs)
|
super().forward(*inputs)
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
|
def _show_only_tag(self) -> bool:
|
||||||
|
return self.__class__ == Passthrough
|
||||||
|
|
||||||
|
|
||||||
class Sum(Chain):
|
class Sum(Chain):
|
||||||
|
_tag = "SUM"
|
||||||
|
|
||||||
def forward(self, *inputs: Any) -> Any:
|
def forward(self, *inputs: Any) -> Any:
|
||||||
output = None
|
output = None
|
||||||
for layer in self:
|
for layer in self:
|
||||||
|
@ -446,6 +445,9 @@ class Sum(Chain):
|
||||||
output = layer_output if output is None else output + layer_output
|
output = layer_output if output is None else output + layer_output
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def _show_only_tag(self) -> bool:
|
||||||
|
return self.__class__ == Sum
|
||||||
|
|
||||||
|
|
||||||
class Residual(Sum):
|
class Residual(Sum):
|
||||||
def __init__(self, *modules: Module) -> None:
|
def __init__(self, *modules: Module) -> None:
|
||||||
|
@ -468,6 +470,7 @@ class Breakpoint(ContextModule):
|
||||||
|
|
||||||
|
|
||||||
class Concatenate(Chain):
|
class Concatenate(Chain):
|
||||||
|
_tag = "CAT"
|
||||||
structural_attrs = ["dim"]
|
structural_attrs = ["dim"]
|
||||||
|
|
||||||
def __init__(self, *modules: Module, dim: int = 0) -> None:
|
def __init__(self, *modules: Module, dim: int = 0) -> None:
|
||||||
|
@ -477,3 +480,6 @@ class Concatenate(Chain):
|
||||||
def forward(self, *args: Any) -> Tensor:
|
def forward(self, *args: Any) -> Tensor:
|
||||||
outputs = [module(*args) for module in self]
|
outputs = [module(*args) for module in self]
|
||||||
return cat([output for output in outputs if output is not None], dim=self.dim)
|
return cat([output for output in outputs if output is not None], dim=self.dim)
|
||||||
|
|
||||||
|
def _show_only_tag(self) -> bool:
|
||||||
|
return self.__class__ == Concatenate
|
||||||
|
|
|
@ -8,11 +8,11 @@ class Conv2d(nn.Conv2d, WeightedModule):
|
||||||
in_channels: int,
|
in_channels: int,
|
||||||
out_channels: int,
|
out_channels: int,
|
||||||
kernel_size: int | tuple[int, int],
|
kernel_size: int | tuple[int, int],
|
||||||
stride: int | tuple[int, int] = 1,
|
stride: int | tuple[int, int] = (1, 1),
|
||||||
padding: int | tuple[int, int] | str = 0,
|
padding: int | tuple[int, int] | str = (0, 0),
|
||||||
groups: int = 1,
|
groups: int = 1,
|
||||||
use_bias: bool = True,
|
use_bias: bool = True,
|
||||||
dilation: int | tuple[int, int] = 1,
|
dilation: int | tuple[int, int] = (1, 1),
|
||||||
padding_mode: str = "zeros",
|
padding_mode: str = "zeros",
|
||||||
device: Device | str | None = None,
|
device: Device | str | None = None,
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
|
@ -30,6 +30,7 @@ class Conv2d(nn.Conv2d, WeightedModule):
|
||||||
device,
|
device,
|
||||||
dtype,
|
dtype,
|
||||||
)
|
)
|
||||||
|
self.use_bias = use_bias
|
||||||
|
|
||||||
|
|
||||||
class Conv1d(nn.Conv1d, WeightedModule):
|
class Conv1d(nn.Conv1d, WeightedModule):
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
|
from inspect import signature, Parameter
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Generator, TypeVar
|
from typing import Any, Generator, TypeVar, TypedDict, cast
|
||||||
|
|
||||||
from torch import device as Device, dtype as DType
|
from torch import device as Device, dtype as DType
|
||||||
from torch.nn.modules.module import Module as TorchModule
|
from torch.nn.modules.module import Module as TorchModule
|
||||||
|
@ -7,18 +8,20 @@ from torch.nn.modules.module import Module as TorchModule
|
||||||
from refiners.fluxion.utils import load_from_safetensors
|
from refiners.fluxion.utils import load_from_safetensors
|
||||||
from refiners.fluxion.context import Context, ContextProvider
|
from refiners.fluxion.context import Context, ContextProvider
|
||||||
|
|
||||||
from typing import Callable, TYPE_CHECKING
|
from typing import Callable, TYPE_CHECKING, Sequence
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from refiners.fluxion.layers.chain import Chain
|
from refiners.fluxion.layers.chain import Chain
|
||||||
|
|
||||||
T = TypeVar("T", bound="Module")
|
T = TypeVar("T", bound="Module")
|
||||||
TContextModule = TypeVar("TContextModule", bound="ContextModule")
|
TContextModule = TypeVar("TContextModule", bound="ContextModule")
|
||||||
|
BasicType = str | float | int | bool
|
||||||
|
|
||||||
|
|
||||||
class Module(TorchModule):
|
class Module(TorchModule):
|
||||||
_parameters: dict[str, Any]
|
_parameters: dict[str, Any]
|
||||||
_buffers: dict[str, Any]
|
_buffers: dict[str, Any]
|
||||||
|
_tag: str = ""
|
||||||
|
|
||||||
__getattr__: Callable[["Module", str], Any] # type: ignore
|
__getattr__: Callable[["Module", str], Any] # type: ignore
|
||||||
__setattr__: Callable[["Module", str, Any], None] # type: ignore
|
__setattr__: Callable[["Module", str, Any], None] # type: ignore
|
||||||
|
@ -37,6 +40,56 @@ class Module(TorchModule):
|
||||||
def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore
|
def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore
|
||||||
return super().to(device=device, dtype=dtype) # type: ignore
|
return super().to(device=device, dtype=dtype) # type: ignore
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
basic_attributes_str = ", ".join(
|
||||||
|
f"{key}={value}" for key, value in self.basic_attributes(init_attrs_only=True).items()
|
||||||
|
)
|
||||||
|
result = f"{self.__class__.__name__}({basic_attributes_str})"
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
tree = ModuleTree(module=self)
|
||||||
|
return repr(tree)
|
||||||
|
|
||||||
|
def pretty_print(self, depth: int = -1) -> None:
|
||||||
|
tree = ModuleTree(module=self)
|
||||||
|
print(tree.generate_tree_repr(tree.root, is_root=True, depth=depth))
|
||||||
|
|
||||||
|
def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType]:
|
||||||
|
"""Return a dictionary of basic attributes of the module.
|
||||||
|
|
||||||
|
Basic attributes are public attributes made of basic types (int, float, str, bool) or a sequence of basic types.
|
||||||
|
"""
|
||||||
|
sig = signature(obj=self.__init__)
|
||||||
|
init_params = set(sig.parameters.keys()) - {"self"}
|
||||||
|
default_values = {k: v.default for k, v in sig.parameters.items() if v.default is not Parameter.empty}
|
||||||
|
|
||||||
|
def is_basic_attribute(key: str, value: Any) -> bool:
|
||||||
|
if key.startswith("_"):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if isinstance(value, BasicType):
|
||||||
|
return True
|
||||||
|
|
||||||
|
if isinstance(value, Sequence) and all(isinstance(y, BasicType) for y in cast(Sequence[Any], value)):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
return {
|
||||||
|
key: str(object=value)
|
||||||
|
for key, value in self.__dict__.items()
|
||||||
|
if is_basic_attribute(key=key, value=value)
|
||||||
|
and (not init_attrs_only or (key in init_params and value != default_values.get(key)))
|
||||||
|
}
|
||||||
|
|
||||||
|
def _show_only_tag(self) -> bool:
|
||||||
|
"""Whether to show only the tag when printing the module.
|
||||||
|
|
||||||
|
This is useful to distinguish between Chain subclasses that override their forward from one another.
|
||||||
|
"""
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class ContextModule(Module):
|
class ContextModule(Module):
|
||||||
# we store parent into a one element list to avoid pytorch thinking it's a submodule
|
# we store parent into a one element list to avoid pytorch thinking it's a submodule
|
||||||
|
@ -100,3 +153,73 @@ class WeightedModule(Module):
|
||||||
@property
|
@property
|
||||||
def dtype(self) -> DType:
|
def dtype(self) -> DType:
|
||||||
return self.weight.dtype
|
return self.weight.dtype
|
||||||
|
|
||||||
|
|
||||||
|
class TreeNode(TypedDict):
|
||||||
|
value: str
|
||||||
|
children: list["TreeNode"]
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleTree:
|
||||||
|
def __init__(self, module: Module) -> None:
|
||||||
|
self.root: TreeNode = self._module_to_tree(module=module)
|
||||||
|
self._fold_successive_identical(node=self.root)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"{self.__class__.__name__}(root={self.root['value']})"
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return self.generate_tree_repr(node=self.root, is_root=True, depth=7)
|
||||||
|
|
||||||
|
def generate_tree_repr(
|
||||||
|
self, node: TreeNode, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1
|
||||||
|
) -> str:
|
||||||
|
if depth == 0:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if depth > 0:
|
||||||
|
depth -= 1
|
||||||
|
|
||||||
|
tree_icon: str = "" if is_root else ("└── " if is_last else "├── ")
|
||||||
|
lines = [f"{prefix}{tree_icon}{node['value']}"]
|
||||||
|
new_prefix: str = " " if is_last else "│ "
|
||||||
|
|
||||||
|
for i, child in enumerate(iterable=node["children"]):
|
||||||
|
lines.append(
|
||||||
|
self.generate_tree_repr(
|
||||||
|
node=child,
|
||||||
|
prefix=prefix + new_prefix,
|
||||||
|
is_last=i == len(node["children"]) - 1,
|
||||||
|
is_root=False,
|
||||||
|
depth=depth,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return "\n".join(filter(bool, lines))
|
||||||
|
|
||||||
|
def _module_to_tree(self, module: Module) -> TreeNode:
|
||||||
|
match (module._tag, module._show_only_tag()): # pyright: ignore[reportPrivateUsage]
|
||||||
|
case ("", False):
|
||||||
|
value = str(object=module)
|
||||||
|
case (_, True):
|
||||||
|
value = f"({module._tag})" # pyright: ignore[reportPrivateUsage]
|
||||||
|
case (_, False):
|
||||||
|
value = f"({module._tag}) {module}" # pyright: ignore[reportPrivateUsage]
|
||||||
|
|
||||||
|
node: TreeNode = {"value": value, "children": []}
|
||||||
|
for child in module.children():
|
||||||
|
node["children"].append(self._module_to_tree(module=child)) # type: ignore
|
||||||
|
return node
|
||||||
|
|
||||||
|
def _fold_successive_identical(self, node: TreeNode) -> None:
|
||||||
|
i = 0
|
||||||
|
while i < len(node["children"]):
|
||||||
|
j = i
|
||||||
|
while j < len(node["children"]) and node["children"][i] == node["children"][j]:
|
||||||
|
j += 1
|
||||||
|
count = j - i
|
||||||
|
if count > 1:
|
||||||
|
node["children"][i]["value"] += f" (x{count})"
|
||||||
|
del node["children"][i + 1 : j]
|
||||||
|
self._fold_successive_identical(node=node["children"][i])
|
||||||
|
i += 1
|
||||||
|
|
Loading…
Reference in a new issue