improve debug print for chains

This commit is contained in:
Benjamin Trom 2023-10-10 15:02:52 +02:00
parent a663375dc7
commit 0024191c58
4 changed files with 148 additions and 24 deletions

View file

@ -1,10 +1,15 @@
from collections import defaultdict
import inspect
import re
import sys
import traceback
from typing import Any, Callable, Iterable, Iterator, TypeVar, cast, overload
import torch
from torch import Tensor, cat, device as Device, dtype as DType
from refiners.fluxion.layers.basics import Identity
from refiners.fluxion.layers.module import Module, ContextModule, WeightedModule
from refiners.fluxion.layers.module import Module, ContextModule, ModuleTree, WeightedModule
from refiners.fluxion.context import Contexts, ContextProvider
from refiners.fluxion.utils import summarize_tensor
T = TypeVar("T", bound=Module)
@ -109,6 +114,13 @@ def structural_copy(m: T) -> T:
return m.structural_copy() if isinstance(m, ContextModule) else m
class ChainError(RuntimeError):
    """Error reported when running a Chain fails.

    The exception is built from exactly one message string, which must be
    passed positionally.
    """

    def __init__(self, message: str, /) -> None:
        # Forward the message unchanged so that `str(error)` is the message itself.
        super().__init__(message)
class Chain(ContextModule):
_modules: dict[str, Module]
_provider: ContextProvider
@ -173,34 +185,98 @@ class Chain(ContextModule):
self._provider.set_context(context, value)
self._register_provider()
def debug_repr(self, layer_name: str = "") -> str:
lines: list[str] = []
tab = " "
tab_length = 0
for i, parent in enumerate(self.get_parents()[::-1]):
lines.append(f"{tab*tab_length}{'└─ ' if i else ''}{parent.__class__.__name__}")
tab_length += 1
def _show_error_in_tree(self, name: str, /, max_lines: int = 20) -> str:
tree = ModuleTree(module=self)
classname_counter: dict[str, int] = defaultdict(int)
first_ancestor = self.get_parents()[-1] if self.get_parents() else self
lines.append(f"{tab*tab_length}└─ {self.__class__.__name__}")
def find_state_dict_key(module: Module, /) -> str | None:
for key, layer in module.named_modules():
if layer == self:
return ".".join((key, name))
return None
for name, _ in self._modules.items():
error_arrow = "⚠️" if name == layer_name else ""
lines.append(f"{tab*tab_length} | {name} {error_arrow}")
for child in tree:
classname, count = name.rsplit(sep="_", maxsplit=1) if "_" in name else (name, "1")
if child["class_name"] == classname:
classname_counter[classname] += 1
if classname_counter[classname] == int(count):
state_dict_key = find_state_dict_key(first_ancestor)
child["value"] = f">>> {child['value']} | {state_dict_key}"
break
return "\n".join(lines)
tree_repr = tree._generate_tree_repr(tree.root, depth=3) # type: ignore[reportPrivateUsage]
def call_layer(self, layer: Module, layer_name: str, *args: Any):
lines = tree_repr.split(sep="\n")
error_line_idx = next((idx for idx, line in enumerate(iterable=lines) if line.startswith(">>>")), 0)
return ModuleTree.shorten_tree_repr(tree_repr, line_index=error_line_idx, max_lines=max_lines)
@staticmethod
def _pretty_print_args(*args: Any) -> str:
"""
Flatten nested tuples and print tensors with their shape and other informations.
"""
def _flatten_tuple(t: Tensor | tuple[Any, ...], /) -> list[Any]:
if isinstance(t, tuple):
return [item for subtuple in t for item in _flatten_tuple(subtuple)]
else:
return [t]
flat_args = _flatten_tuple(args)
return "\n".join(
[
f"{idx}: {summarize_tensor(arg) if isinstance(arg, Tensor) else arg}"
for idx, arg in enumerate(iterable=flat_args)
]
)
def _filter_traceback(self, *frames: traceback.FrameSummary) -> list[traceback.FrameSummary]:
    """Remove noisy frames from an extracted traceback.

    A frame is dropped when its file path AND its function name both match one of
    the (path pattern, name pattern) pairs listed below; an empty pattern matches
    everything. Note that the last pair drops every function whose name starts
    with an underscore, whatever file it lives in.

    Args:
        frames: The frame summaries to filter.

    Returns:
        The frames that matched no exclusion pair, in their original order.
    """
    patterns_to_exclude = [
        (r"torch/nn/modules/", r"^_call_impl$"),
        (r"torch/nn/functional\.py", r""),
        (r"refiners/fluxion/layers/", r"^_call_layer$"),
        (r"refiners/fluxion/layers/", r"^forward$"),
        (r"refiners/fluxion/layers/chain\.py", r""),
        (r"", r"^_"),
    ]
    # Compile once per call instead of looking the patterns up again for every frame.
    compiled_patterns = [
        (re.compile(filename_pattern), re.compile(name_pattern))
        for filename_pattern, name_pattern in patterns_to_exclude
    ]

    def should_exclude(frame: traceback.FrameSummary, /) -> bool:
        # The path patterns are written with "/" separators; normalise backslashes
        # so that Windows-style paths are matched by the same patterns.
        filename = frame.filename.replace("\\", "/")
        return any(
            filename_re.search(filename) and name_re.search(frame.name)
            for filename_re, name_re in compiled_patterns
        )

    return [frame for frame in frames if not should_exclude(frame)]
def _call_layer(self, layer: Module, name: str, /, *args: Any) -> Any:
try:
return layer(*args)
except Exception as e:
pretty_print = self.debug_repr(layer_name)
raise ValueError(f"Error in layer {layer_name}, args:\n {args}\n \n{pretty_print}") from e
exc_type, _, exc_traceback = sys.exc_info()
assert exc_type
tb_list = traceback.extract_tb(tb=exc_traceback)
filtered_tb_list = self._filter_traceback(*tb_list)
formatted_tb = "".join(traceback.format_list(extracted_list=filtered_tb_list))
pretty_args = Chain._pretty_print_args(args)
error_tree = self._show_error_in_tree(name)
exception_str = re.sub(pattern=r"\n\s*\n", repl="\n", string=str(object=e))
message = f"{formatted_tb}\n{exception_str}\n---------------\n{error_tree}\n{pretty_args}"
if "Error" not in exception_str:
message = f"{exc_type.__name__}:\n {message}"
raise ChainError(message) from None
def forward(self, *args: Any) -> Any:
result: tuple[Any] | Any = None
intermediate_args: tuple[Any, ...] = args
for name, layer in self._modules.items():
result = self.call_layer(layer, name, *intermediate_args)
result = self._call_layer(layer, name, *intermediate_args)
intermediate_args = (result,) if not isinstance(result, tuple) else result
self._reset_context()
@ -409,7 +485,7 @@ class Parallel(Chain):
_tag = "PAR"
def forward(self, *args: Any) -> tuple[Tensor, ...]:
return tuple([self.call_layer(module, name, *args) for name, module in self._modules.items()])
return tuple([self._call_layer(module, name, *args) for name, module in self._modules.items()])
def _show_only_tag(self) -> bool:
return self.__class__ == Parallel
@ -421,7 +497,7 @@ class Distribute(Chain):
def forward(self, *args: Any) -> tuple[Tensor, ...]:
n, m = len(args), len(self._modules)
assert n == m, f"Number of positional arguments ({n}) must match number of sub-modules ({m})."
return tuple([self.call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())])
return tuple([self._call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())])
def _show_only_tag(self) -> bool:
return self.__class__ == Distribute

View file

@ -59,7 +59,7 @@ class Module(TorchModule):
def pretty_print(self, depth: int = -1) -> None:
tree = ModuleTree(module=self)
print(tree.generate_tree_repr(tree.root, is_root=True, depth=depth))
print(tree._generate_tree_repr(tree.root, is_root=True, depth=depth)) # type: ignore[reportPrivateUsage]
def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType]:
"""Return a dictionary of basic attributes of the module.
@ -182,10 +182,22 @@ class ModuleTree:
return f"{self.__class__.__name__}(root={self.root['value']})"
def __repr__(self) -> str:
return self.generate_tree_repr(node=self.root, is_root=True, depth=7)
return self._generate_tree_repr(self.root, is_root=True, depth=7)
def generate_tree_repr(
self, node: TreeNode, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1
def __iter__(self) -> Generator[TreeNode, None, None]:
    """Yield the entries stored under the root node's "children" key, in order."""
    yield from self.root["children"]
@classmethod
def shorten_tree_repr(cls, tree_repr: str, /, line_index: int = 0, max_lines: int = 20) -> str:
    """Shorten the tree representation to at most `max_lines` lines around a given line index.

    The window holds `max_lines // 2` lines before `line_index`, the line at
    `line_index` itself, and enough following lines to reach `max_lines` in total.
    It is clipped (not shifted) at the beginning and end of the text, so fewer
    lines are returned when `line_index` is close to either edge.

    Args:
        tree_repr: The newline-separated tree representation.
        line_index: Index of the line the window is centred on.
        max_lines: Upper bound on the number of lines returned.

    Returns:
        The selected lines joined with newlines.
    """
    lines = tree_repr.split(sep="\n")
    start_idx = max(0, line_index - max_lines // 2)
    # `(max_lines + 1) // 2` counts the centre line plus the lines after it, so the
    # window never exceeds `max_lines` (the previous bound yielded one line too many
    # for even values of `max_lines`).
    end_idx = min(len(lines), line_index + (max_lines + 1) // 2)
    return "\n".join(lines[start_idx:end_idx])
def _generate_tree_repr(
self, node: TreeNode, /, *, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1
) -> str:
if depth == 0 and node["children"]:
return f"{prefix}{'└── ' if is_last else '├── '}{node['value']} ..."
@ -211,7 +223,7 @@ class ModuleTree:
else:
child_value = child["value"]
child_str = self.generate_tree_repr(
child_str = self._generate_tree_repr(
{"value": child_value, "class_name": child["class_name"], "children": child["children"]},
prefix=prefix + new_prefix,
is_last=i == len(node["children"]) - 1,

View file

@ -137,3 +137,23 @@ def load_metadata_from_safetensors(path: Path | str) -> dict[str, str] | None:
def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None:
_save_file(tensors, path, metadata) # type: ignore
def summarize_tensor(tensor: torch.Tensor, /) -> str:
    """Return a one-line, human-readable summary of a tensor.

    The summary always contains the shape, dtype (without the `torch.` prefix),
    device and `requires_grad` flag. For non-empty tensors it also contains
    min/max (skipped for complex tensors, which have no ordering) and the mean,
    standard deviation and 2-norm over all elements, each with two decimals.

    The function is meant for debug output, so it avoids raising on inputs the
    plain reductions reject: integer/bool tensors are cast to float before
    computing mean/std/norm, complex tensors are summarised through their
    magnitudes, and empty tensors get no statistics at all.

    Args:
        tensor: The tensor to describe.

    Returns:
        A string of the form `Tensor(shape=(...), dtype=..., device=..., ..., grad=...)`.
    """
    info: list[str] = [
        f"shape=({', '.join(map(str, tensor.shape))})",
        f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",
        f"device={tensor.device}",
    ]

    # Reductions such as min/max raise on empty tensors, so statistics are only
    # reported when there is at least one element.
    if tensor.numel() > 0:
        if tensor.is_complex():
            values = tensor.abs()
        else:
            info.append(f"min={tensor.min():.2f}")  # type: ignore
            info.append(f"max={tensor.max():.2f}")  # type: ignore
            # mean/std only accept floating point input; floating tensors are kept
            # in their own dtype so their output is unchanged.
            values = tensor if tensor.is_floating_point() else tensor.float()

        # NOTE(review): the original went through the module's `norm` helper (not
        # visible here), presumed to be the 2-norm over all elements - confirm.
        info.extend(
            [
                f"mean={values.mean():.2f}",
                f"std={values.std():.2f}",
                f"norm={torch.linalg.vector_norm(values):.2f}",
            ]
        )

    info.append(f"grad={tensor.requires_grad}")
    return "Tensor(" + ", ".join(info) + ")"

View file

@ -226,3 +226,19 @@ def test_setattr_dont_register() -> None:
chain.foo = fl.Linear(in_features=1, out_features=1)
assert module_keys(chain=chain) == ["Linear_1", "Linear_2"]
EXPECTED_TREE = (
"(CHAIN)\n ├── Linear(in_features=1, out_features=1) (x2)\n └── (CHAIN)\n ├── Linear(in_features=1,"
" out_features=1) #1\n └── Linear(in_features=2, out_features=1) #2"
)
# Builds a chain of two Linear(1, 1) layers followed by a nested Chain, then checks that
# the private `_show_error_in_tree` helper, called with "Chain.Linear_2", returns exactly
# the text stored in EXPECTED_TREE.
def test_debug_print() -> None:
chain = fl.Chain(
fl.Linear(1, 1),
fl.Linear(1, 1),
fl.Chain(fl.Linear(1, 1), fl.Linear(2, 1)),
)
assert chain._show_error_in_tree("Chain.Linear_2") == EXPECTED_TREE # type: ignore[reportPrivateUsage]