improve debug print for chains

This commit is contained in:
Benjamin Trom 2023-10-10 15:02:52 +02:00
parent a663375dc7
commit 0024191c58
4 changed files with 148 additions and 24 deletions

View file

@ -1,10 +1,15 @@
from collections import defaultdict
import inspect import inspect
import re
import sys
import traceback
from typing import Any, Callable, Iterable, Iterator, TypeVar, cast, overload from typing import Any, Callable, Iterable, Iterator, TypeVar, cast, overload
import torch import torch
from torch import Tensor, cat, device as Device, dtype as DType from torch import Tensor, cat, device as Device, dtype as DType
from refiners.fluxion.layers.basics import Identity from refiners.fluxion.layers.basics import Identity
from refiners.fluxion.layers.module import Module, ContextModule, WeightedModule from refiners.fluxion.layers.module import Module, ContextModule, ModuleTree, WeightedModule
from refiners.fluxion.context import Contexts, ContextProvider from refiners.fluxion.context import Contexts, ContextProvider
from refiners.fluxion.utils import summarize_tensor
T = TypeVar("T", bound=Module) T = TypeVar("T", bound=Module)
@ -109,6 +114,13 @@ def structural_copy(m: T) -> T:
return m.structural_copy() if isinstance(m, ContextModule) else m return m.structural_copy() if isinstance(m, ContextModule) else m
class ChainError(RuntimeError):
"""Exception raised when an error occurs during the execution of a Chain."""
def __init__(self, message: str, /) -> None:
super().__init__(message)
class Chain(ContextModule): class Chain(ContextModule):
_modules: dict[str, Module] _modules: dict[str, Module]
_provider: ContextProvider _provider: ContextProvider
@ -173,34 +185,98 @@ class Chain(ContextModule):
self._provider.set_context(context, value) self._provider.set_context(context, value)
self._register_provider() self._register_provider()
def debug_repr(self, layer_name: str = "") -> str: def _show_error_in_tree(self, name: str, /, max_lines: int = 20) -> str:
lines: list[str] = [] tree = ModuleTree(module=self)
tab = " " classname_counter: dict[str, int] = defaultdict(int)
tab_length = 0 first_ancestor = self.get_parents()[-1] if self.get_parents() else self
for i, parent in enumerate(self.get_parents()[::-1]):
lines.append(f"{tab*tab_length}{'└─ ' if i else ''}{parent.__class__.__name__}")
tab_length += 1
lines.append(f"{tab*tab_length}└─ {self.__class__.__name__}") def find_state_dict_key(module: Module, /) -> str | None:
for key, layer in module.named_modules():
if layer == self:
return ".".join((key, name))
return None
for name, _ in self._modules.items(): for child in tree:
error_arrow = "⚠️" if name == layer_name else "" classname, count = name.rsplit(sep="_", maxsplit=1) if "_" in name else (name, "1")
lines.append(f"{tab*tab_length} | {name} {error_arrow}") if child["class_name"] == classname:
classname_counter[classname] += 1
if classname_counter[classname] == int(count):
state_dict_key = find_state_dict_key(first_ancestor)
child["value"] = f">>> {child['value']} | {state_dict_key}"
break
return "\n".join(lines) tree_repr = tree._generate_tree_repr(tree.root, depth=3) # type: ignore[reportPrivateUsage]
def call_layer(self, layer: Module, layer_name: str, *args: Any): lines = tree_repr.split(sep="\n")
error_line_idx = next((idx for idx, line in enumerate(iterable=lines) if line.startswith(">>>")), 0)
return ModuleTree.shorten_tree_repr(tree_repr, line_index=error_line_idx, max_lines=max_lines)
@staticmethod
def _pretty_print_args(*args: Any) -> str:
"""
Flatten nested tuples and print tensors with their shape and other informations.
"""
def _flatten_tuple(t: Tensor | tuple[Any, ...], /) -> list[Any]:
if isinstance(t, tuple):
return [item for subtuple in t for item in _flatten_tuple(subtuple)]
else:
return [t]
flat_args = _flatten_tuple(args)
return "\n".join(
[
f"{idx}: {summarize_tensor(arg) if isinstance(arg, Tensor) else arg}"
for idx, arg in enumerate(iterable=flat_args)
]
)
def _filter_traceback(self, *frames: traceback.FrameSummary) -> list[traceback.FrameSummary]:
patterns_to_exclude = [
(r"torch/nn/modules/", r"^_call_impl$"),
(r"torch/nn/functional\.py", r""),
(r"refiners/fluxion/layers/", r"^_call_layer$"),
(r"refiners/fluxion/layers/", r"^forward$"),
(r"refiners/fluxion/layers/chain\.py", r""),
(r"", r"^_"),
]
def should_exclude(frame: traceback.FrameSummary, /) -> bool:
for filename_pattern, name_pattern in patterns_to_exclude:
if re.search(pattern=filename_pattern, string=frame.filename) and re.search(
pattern=name_pattern, string=frame.name
):
return True
return False
return [frame for frame in frames if not should_exclude(frame)]
def _call_layer(self, layer: Module, name: str, /, *args: Any) -> Any:
try: try:
return layer(*args) return layer(*args)
except Exception as e: except Exception as e:
pretty_print = self.debug_repr(layer_name) exc_type, _, exc_traceback = sys.exc_info()
raise ValueError(f"Error in layer {layer_name}, args:\n {args}\n \n{pretty_print}") from e assert exc_type
tb_list = traceback.extract_tb(tb=exc_traceback)
filtered_tb_list = self._filter_traceback(*tb_list)
formatted_tb = "".join(traceback.format_list(extracted_list=filtered_tb_list))
pretty_args = Chain._pretty_print_args(args)
error_tree = self._show_error_in_tree(name)
exception_str = re.sub(pattern=r"\n\s*\n", repl="\n", string=str(object=e))
message = f"{formatted_tb}\n{exception_str}\n---------------\n{error_tree}\n{pretty_args}"
if "Error" not in exception_str:
message = f"{exc_type.__name__}:\n {message}"
raise ChainError(message) from None
def forward(self, *args: Any) -> Any: def forward(self, *args: Any) -> Any:
result: tuple[Any] | Any = None result: tuple[Any] | Any = None
intermediate_args: tuple[Any, ...] = args intermediate_args: tuple[Any, ...] = args
for name, layer in self._modules.items(): for name, layer in self._modules.items():
result = self.call_layer(layer, name, *intermediate_args) result = self._call_layer(layer, name, *intermediate_args)
intermediate_args = (result,) if not isinstance(result, tuple) else result intermediate_args = (result,) if not isinstance(result, tuple) else result
self._reset_context() self._reset_context()
@ -409,7 +485,7 @@ class Parallel(Chain):
_tag = "PAR" _tag = "PAR"
def forward(self, *args: Any) -> tuple[Tensor, ...]: def forward(self, *args: Any) -> tuple[Tensor, ...]:
return tuple([self.call_layer(module, name, *args) for name, module in self._modules.items()]) return tuple([self._call_layer(module, name, *args) for name, module in self._modules.items()])
def _show_only_tag(self) -> bool: def _show_only_tag(self) -> bool:
return self.__class__ == Parallel return self.__class__ == Parallel
@ -421,7 +497,7 @@ class Distribute(Chain):
def forward(self, *args: Any) -> tuple[Tensor, ...]: def forward(self, *args: Any) -> tuple[Tensor, ...]:
n, m = len(args), len(self._modules) n, m = len(args), len(self._modules)
assert n == m, f"Number of positional arguments ({n}) must match number of sub-modules ({m})." assert n == m, f"Number of positional arguments ({n}) must match number of sub-modules ({m})."
return tuple([self.call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())]) return tuple([self._call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())])
def _show_only_tag(self) -> bool: def _show_only_tag(self) -> bool:
return self.__class__ == Distribute return self.__class__ == Distribute

View file

@ -59,7 +59,7 @@ class Module(TorchModule):
def pretty_print(self, depth: int = -1) -> None: def pretty_print(self, depth: int = -1) -> None:
tree = ModuleTree(module=self) tree = ModuleTree(module=self)
print(tree.generate_tree_repr(tree.root, is_root=True, depth=depth)) print(tree._generate_tree_repr(tree.root, is_root=True, depth=depth)) # type: ignore[reportPrivateUsage]
def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType]: def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType]:
"""Return a dictionary of basic attributes of the module. """Return a dictionary of basic attributes of the module.
@ -182,10 +182,22 @@ class ModuleTree:
return f"{self.__class__.__name__}(root={self.root['value']})" return f"{self.__class__.__name__}(root={self.root['value']})"
def __repr__(self) -> str: def __repr__(self) -> str:
return self.generate_tree_repr(node=self.root, is_root=True, depth=7) return self._generate_tree_repr(self.root, is_root=True, depth=7)
def generate_tree_repr( def __iter__(self) -> Generator[TreeNode, None, None]:
self, node: TreeNode, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1 for child in self.root["children"]:
yield child
@classmethod
def shorten_tree_repr(cls, tree_repr: str, /, line_index: int = 0, max_lines: int = 20) -> str:
"""Shorten the tree representation to a given number of lines around a given line index."""
lines = tree_repr.split(sep="\n")
start_idx = max(0, line_index - max_lines // 2)
end_idx = min(len(lines), line_index + max_lines // 2 + 1)
return "\n".join(lines[start_idx:end_idx])
def _generate_tree_repr(
self, node: TreeNode, /, *, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1
) -> str: ) -> str:
if depth == 0 and node["children"]: if depth == 0 and node["children"]:
return f"{prefix}{'└── ' if is_last else '├── '}{node['value']} ..." return f"{prefix}{'└── ' if is_last else '├── '}{node['value']} ..."
@ -211,7 +223,7 @@ class ModuleTree:
else: else:
child_value = child["value"] child_value = child["value"]
child_str = self.generate_tree_repr( child_str = self._generate_tree_repr(
{"value": child_value, "class_name": child["class_name"], "children": child["children"]}, {"value": child_value, "class_name": child["class_name"], "children": child["children"]},
prefix=prefix + new_prefix, prefix=prefix + new_prefix,
is_last=i == len(node["children"]) - 1, is_last=i == len(node["children"]) - 1,

View file

@ -137,3 +137,23 @@ def load_metadata_from_safetensors(path: Path | str) -> dict[str, str] | None:
def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None: def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None:
_save_file(tensors, path, metadata) # type: ignore _save_file(tensors, path, metadata) # type: ignore
def summarize_tensor(tensor: torch.Tensor, /) -> str:
return (
"Tensor("
+ ", ".join(
[
f"shape=({', '.join(map(str, tensor.shape))})",
f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",
f"device={tensor.device}",
f"min={tensor.min():.2f}", # type: ignore
f"max={tensor.max():.2f}", # type: ignore
f"mean={tensor.mean():.2f}",
f"std={tensor.std():.2f}",
f"norm={norm(x=tensor):.2f}",
f"grad={tensor.requires_grad}",
]
)
+ ")"
)

View file

@ -226,3 +226,19 @@ def test_setattr_dont_register() -> None:
chain.foo = fl.Linear(in_features=1, out_features=1) chain.foo = fl.Linear(in_features=1, out_features=1)
assert module_keys(chain=chain) == ["Linear_1", "Linear_2"] assert module_keys(chain=chain) == ["Linear_1", "Linear_2"]
EXPECTED_TREE = (
"(CHAIN)\n ├── Linear(in_features=1, out_features=1) (x2)\n └── (CHAIN)\n ├── Linear(in_features=1,"
" out_features=1) #1\n └── Linear(in_features=2, out_features=1) #2"
)
def test_debug_print() -> None:
chain = fl.Chain(
fl.Linear(1, 1),
fl.Linear(1, 1),
fl.Chain(fl.Linear(1, 1), fl.Linear(2, 1)),
)
assert chain._show_error_in_tree("Chain.Linear_2") == EXPECTED_TREE # type: ignore[reportPrivateUsage]