From dc1fc239aa3b6db39abce0713e72230667d2f811 Mon Sep 17 00:00:00 2001 From: Benjamin Trom Date: Thu, 14 Sep 2023 16:47:27 +0200 Subject: [PATCH] show an ellipsis when chain has been shortened because of depth and count siblings with same class name --- src/refiners/fluxion/layers/module.py | 46 +++++++++++++++++++-------- 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/src/refiners/fluxion/layers/module.py b/src/refiners/fluxion/layers/module.py index 3e5cf1a..88c77b6 100644 --- a/src/refiners/fluxion/layers/module.py +++ b/src/refiners/fluxion/layers/module.py @@ -1,8 +1,9 @@ +from collections import defaultdict from inspect import signature, Parameter import sys from pathlib import Path from types import ModuleType -from typing import Any, Generator, TypeVar, TypedDict, cast +from typing import Any, DefaultDict, Generator, TypeVar, TypedDict, cast from torch import device as Device, dtype as DType from torch.nn.modules.module import Module as TorchModule @@ -165,6 +166,7 @@ class WeightedModule(Module): class TreeNode(TypedDict): value: str + class_name: str children: list["TreeNode"] @@ -182,39 +184,55 @@ class ModuleTree: def generate_tree_repr( self, node: TreeNode, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1 ) -> str: - if depth == 0: - return "" + if depth == 0 and node["children"]: + return f"{prefix}{'└── ' if is_last else '├── '}{node['value']} ..." if depth > 0: depth -= 1 tree_icon: str = "" if is_root else ("└── " if is_last else "├── ") + counts: DefaultDict[str, int] = defaultdict(int) + + for child in node["children"]: + counts[child["class_name"]] += 1 + + instance_counts: DefaultDict[str, int] = defaultdict(int) lines = [f"{prefix}{tree_icon}{node['value']}"] new_prefix: str = " " if is_last else "│ " for i, child in enumerate(iterable=node["children"]): - lines.append( - self.generate_tree_repr( - node=child, - prefix=prefix + new_prefix, - is_last=i == len(node["children"]) - 1, - is_root=False, - depth=depth, - ) + instance_counts[child["class_name"]] += 1 + + if counts[child["class_name"]] > 1: + child_value = f"{child['value']} #{instance_counts[child['class_name']]}" + else: + child_value = child["value"] + + child_str = self.generate_tree_repr( + {"value": child_value, "class_name": child["class_name"], "children": child["children"]}, + prefix=prefix + new_prefix, + is_last=i == len(node["children"]) - 1, + is_root=False, + depth=depth, ) - return "\n".join(filter(bool, lines)) + if child_str: + lines.append(child_str) + + return "\n".join(lines) def _module_to_tree(self, module: Module) -> TreeNode: match (module._tag, module._show_only_tag()): # pyright: ignore[reportPrivateUsage] case ("", False): - value = str(object=module) + value = str(module) case (_, True): value = f"({module._tag})" # pyright: ignore[reportPrivateUsage] case (_, False): value = f"({module._tag}) {module}" # pyright: ignore[reportPrivateUsage] - node: TreeNode = {"value": value, "children": []} + class_name = module.__class__.__name__ + + node: TreeNode = {"value": value, "class_name": class_name, "children": []} for child in module.children(): node["children"].append(self._module_to_tree(module=child)) # type: ignore return node