show an ellipsis when chain has been shortened because of depth and count siblings with same class name

This commit is contained in:
Benjamin Trom 2023-09-14 16:47:27 +02:00
parent 1cb798e8ae
commit dc1fc239aa

View file

@ -1,8 +1,9 @@
from collections import defaultdict
from inspect import signature, Parameter
import sys
from pathlib import Path
from types import ModuleType
from typing import Any, Generator, TypeVar, TypedDict, cast
from typing import Any, DefaultDict, Generator, TypeVar, TypedDict, cast
from torch import device as Device, dtype as DType
from torch.nn.modules.module import Module as TorchModule
@ -165,6 +166,7 @@ class WeightedModule(Module):
class TreeNode(TypedDict):
value: str
class_name: str
children: list["TreeNode"]
@ -182,39 +184,55 @@ class ModuleTree:
def generate_tree_repr(
self, node: TreeNode, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1
) -> str:
if depth == 0:
return ""
if depth == 0 and node["children"]:
return f"{prefix}{'└── ' if is_last else '├── '}{node['value']} ..."
if depth > 0:
depth -= 1
tree_icon: str = "" if is_root else ("└── " if is_last else "├── ")
counts: DefaultDict[str, int] = defaultdict(int)
for child in node["children"]:
counts[child["class_name"]] += 1
instance_counts: DefaultDict[str, int] = defaultdict(int)
lines = [f"{prefix}{tree_icon}{node['value']}"]
new_prefix: str = " " if is_last else ""
for i, child in enumerate(iterable=node["children"]):
lines.append(
self.generate_tree_repr(
node=child,
prefix=prefix + new_prefix,
is_last=i == len(node["children"]) - 1,
is_root=False,
depth=depth,
)
instance_counts[child["class_name"]] += 1
if counts[child["class_name"]] > 1:
child_value = f"{child['value']} #{instance_counts[child['class_name']]}"
else:
child_value = child["value"]
child_str = self.generate_tree_repr(
{"value": child_value, "class_name": child["class_name"], "children": child["children"]},
prefix=prefix + new_prefix,
is_last=i == len(node["children"]) - 1,
is_root=False,
depth=depth,
)
return "\n".join(filter(bool, lines))
if child_str:
lines.append(child_str)
return "\n".join(lines)
def _module_to_tree(self, module: Module) -> TreeNode:
match (module._tag, module._show_only_tag()): # pyright: ignore[reportPrivateUsage]
case ("", False):
value = str(object=module)
value = str(module)
case (_, True):
value = f"({module._tag})" # pyright: ignore[reportPrivateUsage]
case (_, False):
value = f"({module._tag}) {module}" # pyright: ignore[reportPrivateUsage]
node: TreeNode = {"value": value, "children": []}
class_name = module.__class__.__name__
node: TreeNode = {"value": value, "class_name": class_name, "children": []}
for child in module.children():
node["children"].append(self._module_to_tree(module=child)) # type: ignore
return node