mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
show an ellipsis when chain has been shortened because of depth and count siblings with same class name
This commit is contained in:
parent
1cb798e8ae
commit
dc1fc239aa
|
@ -1,8 +1,9 @@
|
||||||
|
from collections import defaultdict
|
||||||
from inspect import signature, Parameter
|
from inspect import signature, Parameter
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import Any, Generator, TypeVar, TypedDict, cast
|
from typing import Any, DefaultDict, Generator, TypeVar, TypedDict, cast
|
||||||
|
|
||||||
from torch import device as Device, dtype as DType
|
from torch import device as Device, dtype as DType
|
||||||
from torch.nn.modules.module import Module as TorchModule
|
from torch.nn.modules.module import Module as TorchModule
|
||||||
|
@ -165,6 +166,7 @@ class WeightedModule(Module):
|
||||||
|
|
||||||
class TreeNode(TypedDict):
|
class TreeNode(TypedDict):
|
||||||
value: str
|
value: str
|
||||||
|
class_name: str
|
||||||
children: list["TreeNode"]
|
children: list["TreeNode"]
|
||||||
|
|
||||||
|
|
||||||
|
@ -182,39 +184,55 @@ class ModuleTree:
|
||||||
def generate_tree_repr(
|
def generate_tree_repr(
|
||||||
self, node: TreeNode, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1
|
self, node: TreeNode, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1
|
||||||
) -> str:
|
) -> str:
|
||||||
if depth == 0:
|
if depth == 0 and node["children"]:
|
||||||
return ""
|
return f"{prefix}{'└── ' if is_last else '├── '}{node['value']} ..."
|
||||||
|
|
||||||
if depth > 0:
|
if depth > 0:
|
||||||
depth -= 1
|
depth -= 1
|
||||||
|
|
||||||
tree_icon: str = "" if is_root else ("└── " if is_last else "├── ")
|
tree_icon: str = "" if is_root else ("└── " if is_last else "├── ")
|
||||||
|
counts: DefaultDict[str, int] = defaultdict(int)
|
||||||
|
|
||||||
|
for child in node["children"]:
|
||||||
|
counts[child["class_name"]] += 1
|
||||||
|
|
||||||
|
instance_counts: DefaultDict[str, int] = defaultdict(int)
|
||||||
lines = [f"{prefix}{tree_icon}{node['value']}"]
|
lines = [f"{prefix}{tree_icon}{node['value']}"]
|
||||||
new_prefix: str = " " if is_last else "│ "
|
new_prefix: str = " " if is_last else "│ "
|
||||||
|
|
||||||
for i, child in enumerate(iterable=node["children"]):
|
for i, child in enumerate(iterable=node["children"]):
|
||||||
lines.append(
|
instance_counts[child["class_name"]] += 1
|
||||||
self.generate_tree_repr(
|
|
||||||
node=child,
|
if counts[child["class_name"]] > 1:
|
||||||
prefix=prefix + new_prefix,
|
child_value = f"{child['value']} #{instance_counts[child['class_name']]}"
|
||||||
is_last=i == len(node["children"]) - 1,
|
else:
|
||||||
is_root=False,
|
child_value = child["value"]
|
||||||
depth=depth,
|
|
||||||
)
|
child_str = self.generate_tree_repr(
|
||||||
|
{"value": child_value, "class_name": child["class_name"], "children": child["children"]},
|
||||||
|
prefix=prefix + new_prefix,
|
||||||
|
is_last=i == len(node["children"]) - 1,
|
||||||
|
is_root=False,
|
||||||
|
depth=depth,
|
||||||
)
|
)
|
||||||
|
|
||||||
return "\n".join(filter(bool, lines))
|
if child_str:
|
||||||
|
lines.append(child_str)
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
def _module_to_tree(self, module: Module) -> TreeNode:
|
def _module_to_tree(self, module: Module) -> TreeNode:
|
||||||
match (module._tag, module._show_only_tag()): # pyright: ignore[reportPrivateUsage]
|
match (module._tag, module._show_only_tag()): # pyright: ignore[reportPrivateUsage]
|
||||||
case ("", False):
|
case ("", False):
|
||||||
value = str(object=module)
|
value = str(module)
|
||||||
case (_, True):
|
case (_, True):
|
||||||
value = f"({module._tag})" # pyright: ignore[reportPrivateUsage]
|
value = f"({module._tag})" # pyright: ignore[reportPrivateUsage]
|
||||||
case (_, False):
|
case (_, False):
|
||||||
value = f"({module._tag}) {module}" # pyright: ignore[reportPrivateUsage]
|
value = f"({module._tag}) {module}" # pyright: ignore[reportPrivateUsage]
|
||||||
|
|
||||||
node: TreeNode = {"value": value, "children": []}
|
class_name = module.__class__.__name__
|
||||||
|
|
||||||
|
node: TreeNode = {"value": value, "class_name": class_name, "children": []}
|
||||||
for child in module.children():
|
for child in module.children():
|
||||||
node["children"].append(self._module_to_tree(module=child)) # type: ignore
|
node["children"].append(self._module_to_tree(module=child)) # type: ignore
|
||||||
return node
|
return node
|
||||||
|
|
Loading…
Reference in a new issue