mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
improve debug print for chains
This commit is contained in:
parent
a663375dc7
commit
0024191c58
|
@ -1,10 +1,15 @@
|
|||
from collections import defaultdict
|
||||
import inspect
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Any, Callable, Iterable, Iterator, TypeVar, cast, overload
|
||||
import torch
|
||||
from torch import Tensor, cat, device as Device, dtype as DType
|
||||
from refiners.fluxion.layers.basics import Identity
|
||||
from refiners.fluxion.layers.module import Module, ContextModule, WeightedModule
|
||||
from refiners.fluxion.layers.module import Module, ContextModule, ModuleTree, WeightedModule
|
||||
from refiners.fluxion.context import Contexts, ContextProvider
|
||||
from refiners.fluxion.utils import summarize_tensor
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Module)
|
||||
|
@ -109,6 +114,13 @@ def structural_copy(m: T) -> T:
|
|||
return m.structural_copy() if isinstance(m, ContextModule) else m
|
||||
|
||||
|
||||
class ChainError(RuntimeError):
|
||||
"""Exception raised when an error occurs during the execution of a Chain."""
|
||||
|
||||
def __init__(self, message: str, /) -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class Chain(ContextModule):
|
||||
_modules: dict[str, Module]
|
||||
_provider: ContextProvider
|
||||
|
@ -173,34 +185,98 @@ class Chain(ContextModule):
|
|||
self._provider.set_context(context, value)
|
||||
self._register_provider()
|
||||
|
||||
def debug_repr(self, layer_name: str = "") -> str:
|
||||
lines: list[str] = []
|
||||
tab = " "
|
||||
tab_length = 0
|
||||
for i, parent in enumerate(self.get_parents()[::-1]):
|
||||
lines.append(f"{tab*tab_length}{'└─ ' if i else ''}{parent.__class__.__name__}")
|
||||
tab_length += 1
|
||||
def _show_error_in_tree(self, name: str, /, max_lines: int = 20) -> str:
|
||||
tree = ModuleTree(module=self)
|
||||
classname_counter: dict[str, int] = defaultdict(int)
|
||||
first_ancestor = self.get_parents()[-1] if self.get_parents() else self
|
||||
|
||||
lines.append(f"{tab*tab_length}└─ {self.__class__.__name__}")
|
||||
def find_state_dict_key(module: Module, /) -> str | None:
|
||||
for key, layer in module.named_modules():
|
||||
if layer == self:
|
||||
return ".".join((key, name))
|
||||
return None
|
||||
|
||||
for name, _ in self._modules.items():
|
||||
error_arrow = "⚠️" if name == layer_name else ""
|
||||
lines.append(f"{tab*tab_length} | {name} {error_arrow}")
|
||||
for child in tree:
|
||||
classname, count = name.rsplit(sep="_", maxsplit=1) if "_" in name else (name, "1")
|
||||
if child["class_name"] == classname:
|
||||
classname_counter[classname] += 1
|
||||
if classname_counter[classname] == int(count):
|
||||
state_dict_key = find_state_dict_key(first_ancestor)
|
||||
child["value"] = f">>> {child['value']} | {state_dict_key}"
|
||||
break
|
||||
|
||||
return "\n".join(lines)
|
||||
tree_repr = tree._generate_tree_repr(tree.root, depth=3) # type: ignore[reportPrivateUsage]
|
||||
|
||||
def call_layer(self, layer: Module, layer_name: str, *args: Any):
|
||||
lines = tree_repr.split(sep="\n")
|
||||
error_line_idx = next((idx for idx, line in enumerate(iterable=lines) if line.startswith(">>>")), 0)
|
||||
|
||||
return ModuleTree.shorten_tree_repr(tree_repr, line_index=error_line_idx, max_lines=max_lines)
|
||||
|
||||
@staticmethod
|
||||
def _pretty_print_args(*args: Any) -> str:
|
||||
"""
|
||||
Flatten nested tuples and print tensors with their shape and other informations.
|
||||
"""
|
||||
|
||||
def _flatten_tuple(t: Tensor | tuple[Any, ...], /) -> list[Any]:
|
||||
if isinstance(t, tuple):
|
||||
return [item for subtuple in t for item in _flatten_tuple(subtuple)]
|
||||
else:
|
||||
return [t]
|
||||
|
||||
flat_args = _flatten_tuple(args)
|
||||
|
||||
return "\n".join(
|
||||
[
|
||||
f"{idx}: {summarize_tensor(arg) if isinstance(arg, Tensor) else arg}"
|
||||
for idx, arg in enumerate(iterable=flat_args)
|
||||
]
|
||||
)
|
||||
|
||||
def _filter_traceback(self, *frames: traceback.FrameSummary) -> list[traceback.FrameSummary]:
|
||||
patterns_to_exclude = [
|
||||
(r"torch/nn/modules/", r"^_call_impl$"),
|
||||
(r"torch/nn/functional\.py", r""),
|
||||
(r"refiners/fluxion/layers/", r"^_call_layer$"),
|
||||
(r"refiners/fluxion/layers/", r"^forward$"),
|
||||
(r"refiners/fluxion/layers/chain\.py", r""),
|
||||
(r"", r"^_"),
|
||||
]
|
||||
|
||||
def should_exclude(frame: traceback.FrameSummary, /) -> bool:
|
||||
for filename_pattern, name_pattern in patterns_to_exclude:
|
||||
if re.search(pattern=filename_pattern, string=frame.filename) and re.search(
|
||||
pattern=name_pattern, string=frame.name
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
return [frame for frame in frames if not should_exclude(frame)]
|
||||
|
||||
def _call_layer(self, layer: Module, name: str, /, *args: Any) -> Any:
|
||||
try:
|
||||
return layer(*args)
|
||||
except Exception as e:
|
||||
pretty_print = self.debug_repr(layer_name)
|
||||
raise ValueError(f"Error in layer {layer_name}, args:\n {args}\n \n{pretty_print}") from e
|
||||
exc_type, _, exc_traceback = sys.exc_info()
|
||||
assert exc_type
|
||||
tb_list = traceback.extract_tb(tb=exc_traceback)
|
||||
filtered_tb_list = self._filter_traceback(*tb_list)
|
||||
formatted_tb = "".join(traceback.format_list(extracted_list=filtered_tb_list))
|
||||
pretty_args = Chain._pretty_print_args(args)
|
||||
error_tree = self._show_error_in_tree(name)
|
||||
|
||||
exception_str = re.sub(pattern=r"\n\s*\n", repl="\n", string=str(object=e))
|
||||
message = f"{formatted_tb}\n{exception_str}\n---------------\n{error_tree}\n{pretty_args}"
|
||||
if "Error" not in exception_str:
|
||||
message = f"{exc_type.__name__}:\n {message}"
|
||||
|
||||
raise ChainError(message) from None
|
||||
|
||||
def forward(self, *args: Any) -> Any:
|
||||
result: tuple[Any] | Any = None
|
||||
intermediate_args: tuple[Any, ...] = args
|
||||
for name, layer in self._modules.items():
|
||||
result = self.call_layer(layer, name, *intermediate_args)
|
||||
result = self._call_layer(layer, name, *intermediate_args)
|
||||
intermediate_args = (result,) if not isinstance(result, tuple) else result
|
||||
|
||||
self._reset_context()
|
||||
|
@ -409,7 +485,7 @@ class Parallel(Chain):
|
|||
_tag = "PAR"
|
||||
|
||||
def forward(self, *args: Any) -> tuple[Tensor, ...]:
|
||||
return tuple([self.call_layer(module, name, *args) for name, module in self._modules.items()])
|
||||
return tuple([self._call_layer(module, name, *args) for name, module in self._modules.items()])
|
||||
|
||||
def _show_only_tag(self) -> bool:
|
||||
return self.__class__ == Parallel
|
||||
|
@ -421,7 +497,7 @@ class Distribute(Chain):
|
|||
def forward(self, *args: Any) -> tuple[Tensor, ...]:
|
||||
n, m = len(args), len(self._modules)
|
||||
assert n == m, f"Number of positional arguments ({n}) must match number of sub-modules ({m})."
|
||||
return tuple([self.call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())])
|
||||
return tuple([self._call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())])
|
||||
|
||||
def _show_only_tag(self) -> bool:
|
||||
return self.__class__ == Distribute
|
||||
|
|
|
@ -59,7 +59,7 @@ class Module(TorchModule):
|
|||
|
||||
def pretty_print(self, depth: int = -1) -> None:
|
||||
tree = ModuleTree(module=self)
|
||||
print(tree.generate_tree_repr(tree.root, is_root=True, depth=depth))
|
||||
print(tree._generate_tree_repr(tree.root, is_root=True, depth=depth)) # type: ignore[reportPrivateUsage]
|
||||
|
||||
def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType]:
|
||||
"""Return a dictionary of basic attributes of the module.
|
||||
|
@ -182,10 +182,22 @@ class ModuleTree:
|
|||
return f"{self.__class__.__name__}(root={self.root['value']})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.generate_tree_repr(node=self.root, is_root=True, depth=7)
|
||||
return self._generate_tree_repr(self.root, is_root=True, depth=7)
|
||||
|
||||
def generate_tree_repr(
|
||||
self, node: TreeNode, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1
|
||||
def __iter__(self) -> Generator[TreeNode, None, None]:
|
||||
for child in self.root["children"]:
|
||||
yield child
|
||||
|
||||
@classmethod
|
||||
def shorten_tree_repr(cls, tree_repr: str, /, line_index: int = 0, max_lines: int = 20) -> str:
|
||||
"""Shorten the tree representation to a given number of lines around a given line index."""
|
||||
lines = tree_repr.split(sep="\n")
|
||||
start_idx = max(0, line_index - max_lines // 2)
|
||||
end_idx = min(len(lines), line_index + max_lines // 2 + 1)
|
||||
return "\n".join(lines[start_idx:end_idx])
|
||||
|
||||
def _generate_tree_repr(
|
||||
self, node: TreeNode, /, *, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1
|
||||
) -> str:
|
||||
if depth == 0 and node["children"]:
|
||||
return f"{prefix}{'└── ' if is_last else '├── '}{node['value']} ..."
|
||||
|
@ -211,7 +223,7 @@ class ModuleTree:
|
|||
else:
|
||||
child_value = child["value"]
|
||||
|
||||
child_str = self.generate_tree_repr(
|
||||
child_str = self._generate_tree_repr(
|
||||
{"value": child_value, "class_name": child["class_name"], "children": child["children"]},
|
||||
prefix=prefix + new_prefix,
|
||||
is_last=i == len(node["children"]) - 1,
|
||||
|
|
|
@ -137,3 +137,23 @@ def load_metadata_from_safetensors(path: Path | str) -> dict[str, str] | None:
|
|||
|
||||
def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None:
|
||||
_save_file(tensors, path, metadata) # type: ignore
|
||||
|
||||
|
||||
def summarize_tensor(tensor: torch.Tensor, /) -> str:
|
||||
return (
|
||||
"Tensor("
|
||||
+ ", ".join(
|
||||
[
|
||||
f"shape=({', '.join(map(str, tensor.shape))})",
|
||||
f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",
|
||||
f"device={tensor.device}",
|
||||
f"min={tensor.min():.2f}", # type: ignore
|
||||
f"max={tensor.max():.2f}", # type: ignore
|
||||
f"mean={tensor.mean():.2f}",
|
||||
f"std={tensor.std():.2f}",
|
||||
f"norm={norm(x=tensor):.2f}",
|
||||
f"grad={tensor.requires_grad}",
|
||||
]
|
||||
)
|
||||
+ ")"
|
||||
)
|
||||
|
|
|
@ -226,3 +226,19 @@ def test_setattr_dont_register() -> None:
|
|||
chain.foo = fl.Linear(in_features=1, out_features=1)
|
||||
|
||||
assert module_keys(chain=chain) == ["Linear_1", "Linear_2"]
|
||||
|
||||
|
||||
EXPECTED_TREE = (
|
||||
"(CHAIN)\n ├── Linear(in_features=1, out_features=1) (x2)\n └── (CHAIN)\n ├── Linear(in_features=1,"
|
||||
" out_features=1) #1\n └── Linear(in_features=2, out_features=1) #2"
|
||||
)
|
||||
|
||||
|
||||
def test_debug_print() -> None:
|
||||
chain = fl.Chain(
|
||||
fl.Linear(1, 1),
|
||||
fl.Linear(1, 1),
|
||||
fl.Chain(fl.Linear(1, 1), fl.Linear(2, 1)),
|
||||
)
|
||||
|
||||
assert chain._show_error_in_tree("Chain.Linear_2") == EXPECTED_TREE # type: ignore[reportPrivateUsage]
|
||||
|
|
Loading…
Reference in a new issue