mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
show an ellipsis when chain has been shortened because of depth and count siblings with same class name
This commit is contained in:
parent
1cb798e8ae
commit
dc1fc239aa
|
@ -1,8 +1,9 @@
|
|||
from collections import defaultdict
|
||||
from inspect import signature, Parameter
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any, Generator, TypeVar, TypedDict, cast
|
||||
from typing import Any, DefaultDict, Generator, TypeVar, TypedDict, cast
|
||||
|
||||
from torch import device as Device, dtype as DType
|
||||
from torch.nn.modules.module import Module as TorchModule
|
||||
|
@ -165,6 +166,7 @@ class WeightedModule(Module):
|
|||
|
||||
class TreeNode(TypedDict):
|
||||
value: str
|
||||
class_name: str
|
||||
children: list["TreeNode"]
|
||||
|
||||
|
||||
|
@ -182,39 +184,55 @@ class ModuleTree:
|
|||
def generate_tree_repr(
|
||||
self, node: TreeNode, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1
|
||||
) -> str:
|
||||
if depth == 0:
|
||||
return ""
|
||||
if depth == 0 and node["children"]:
|
||||
return f"{prefix}{'└── ' if is_last else '├── '}{node['value']} ..."
|
||||
|
||||
if depth > 0:
|
||||
depth -= 1
|
||||
|
||||
tree_icon: str = "" if is_root else ("└── " if is_last else "├── ")
|
||||
counts: DefaultDict[str, int] = defaultdict(int)
|
||||
|
||||
for child in node["children"]:
|
||||
counts[child["class_name"]] += 1
|
||||
|
||||
instance_counts: DefaultDict[str, int] = defaultdict(int)
|
||||
lines = [f"{prefix}{tree_icon}{node['value']}"]
|
||||
new_prefix: str = " " if is_last else "│ "
|
||||
|
||||
for i, child in enumerate(iterable=node["children"]):
|
||||
lines.append(
|
||||
self.generate_tree_repr(
|
||||
node=child,
|
||||
instance_counts[child["class_name"]] += 1
|
||||
|
||||
if counts[child["class_name"]] > 1:
|
||||
child_value = f"{child['value']} #{instance_counts[child['class_name']]}"
|
||||
else:
|
||||
child_value = child["value"]
|
||||
|
||||
child_str = self.generate_tree_repr(
|
||||
{"value": child_value, "class_name": child["class_name"], "children": child["children"]},
|
||||
prefix=prefix + new_prefix,
|
||||
is_last=i == len(node["children"]) - 1,
|
||||
is_root=False,
|
||||
depth=depth,
|
||||
)
|
||||
)
|
||||
|
||||
return "\n".join(filter(bool, lines))
|
||||
if child_str:
|
||||
lines.append(child_str)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _module_to_tree(self, module: Module) -> TreeNode:
|
||||
match (module._tag, module._show_only_tag()): # pyright: ignore[reportPrivateUsage]
|
||||
case ("", False):
|
||||
value = str(object=module)
|
||||
value = str(module)
|
||||
case (_, True):
|
||||
value = f"({module._tag})" # pyright: ignore[reportPrivateUsage]
|
||||
case (_, False):
|
||||
value = f"({module._tag}) {module}" # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
node: TreeNode = {"value": value, "children": []}
|
||||
class_name = module.__class__.__name__
|
||||
|
||||
node: TreeNode = {"value": value, "class_name": class_name, "children": []}
|
||||
for child in module.children():
|
||||
node["children"].append(self._module_to_tree(module=child)) # type: ignore
|
||||
return node
|
||||
|
|
Loading…
Reference in a new issue