mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-14 09:08:14 +00:00
improve debug print for chains
This commit is contained in:
parent
a663375dc7
commit
0024191c58
|
@ -1,10 +1,15 @@
|
||||||
|
from collections import defaultdict
|
||||||
import inspect
|
import inspect
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
from typing import Any, Callable, Iterable, Iterator, TypeVar, cast, overload
|
from typing import Any, Callable, Iterable, Iterator, TypeVar, cast, overload
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, cat, device as Device, dtype as DType
|
from torch import Tensor, cat, device as Device, dtype as DType
|
||||||
from refiners.fluxion.layers.basics import Identity
|
from refiners.fluxion.layers.basics import Identity
|
||||||
from refiners.fluxion.layers.module import Module, ContextModule, WeightedModule
|
from refiners.fluxion.layers.module import Module, ContextModule, ModuleTree, WeightedModule
|
||||||
from refiners.fluxion.context import Contexts, ContextProvider
|
from refiners.fluxion.context import Contexts, ContextProvider
|
||||||
|
from refiners.fluxion.utils import summarize_tensor
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=Module)
|
T = TypeVar("T", bound=Module)
|
||||||
|
@ -109,6 +114,13 @@ def structural_copy(m: T) -> T:
|
||||||
return m.structural_copy() if isinstance(m, ContextModule) else m
|
return m.structural_copy() if isinstance(m, ContextModule) else m
|
||||||
|
|
||||||
|
|
||||||
|
class ChainError(RuntimeError):
    """Exception raised when an error occurs during the execution of a Chain.

    The message is positional-only and is handed unchanged to `RuntimeError`,
    so `str(error)` yields exactly the text that was passed in.
    """

    def __init__(self, message: str, /) -> None:
        super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
class Chain(ContextModule):
|
class Chain(ContextModule):
|
||||||
_modules: dict[str, Module]
|
_modules: dict[str, Module]
|
||||||
_provider: ContextProvider
|
_provider: ContextProvider
|
||||||
|
@ -173,34 +185,98 @@ class Chain(ContextModule):
|
||||||
self._provider.set_context(context, value)
|
self._provider.set_context(context, value)
|
||||||
self._register_provider()
|
self._register_provider()
|
||||||
|
|
||||||
def debug_repr(self, layer_name: str = "") -> str:
|
def _show_error_in_tree(self, name: str, /, max_lines: int = 20) -> str:
|
||||||
lines: list[str] = []
|
tree = ModuleTree(module=self)
|
||||||
tab = " "
|
classname_counter: dict[str, int] = defaultdict(int)
|
||||||
tab_length = 0
|
first_ancestor = self.get_parents()[-1] if self.get_parents() else self
|
||||||
for i, parent in enumerate(self.get_parents()[::-1]):
|
|
||||||
lines.append(f"{tab*tab_length}{'└─ ' if i else ''}{parent.__class__.__name__}")
|
|
||||||
tab_length += 1
|
|
||||||
|
|
||||||
lines.append(f"{tab*tab_length}└─ {self.__class__.__name__}")
|
def find_state_dict_key(module: Module, /) -> str | None:
|
||||||
|
for key, layer in module.named_modules():
|
||||||
|
if layer == self:
|
||||||
|
return ".".join((key, name))
|
||||||
|
return None
|
||||||
|
|
||||||
for name, _ in self._modules.items():
|
for child in tree:
|
||||||
error_arrow = "⚠️" if name == layer_name else ""
|
classname, count = name.rsplit(sep="_", maxsplit=1) if "_" in name else (name, "1")
|
||||||
lines.append(f"{tab*tab_length} | {name} {error_arrow}")
|
if child["class_name"] == classname:
|
||||||
|
classname_counter[classname] += 1
|
||||||
|
if classname_counter[classname] == int(count):
|
||||||
|
state_dict_key = find_state_dict_key(first_ancestor)
|
||||||
|
child["value"] = f">>> {child['value']} | {state_dict_key}"
|
||||||
|
break
|
||||||
|
|
||||||
return "\n".join(lines)
|
tree_repr = tree._generate_tree_repr(tree.root, depth=3) # type: ignore[reportPrivateUsage]
|
||||||
|
|
||||||
def call_layer(self, layer: Module, layer_name: str, *args: Any):
|
lines = tree_repr.split(sep="\n")
|
||||||
|
error_line_idx = next((idx for idx, line in enumerate(iterable=lines) if line.startswith(">>>")), 0)
|
||||||
|
|
||||||
|
return ModuleTree.shorten_tree_repr(tree_repr, line_index=error_line_idx, max_lines=max_lines)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _pretty_print_args(*args: Any) -> str:
|
||||||
|
"""
|
||||||
|
Flatten nested tuples and print tensors with their shape and other informations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _flatten_tuple(t: Tensor | tuple[Any, ...], /) -> list[Any]:
|
||||||
|
if isinstance(t, tuple):
|
||||||
|
return [item for subtuple in t for item in _flatten_tuple(subtuple)]
|
||||||
|
else:
|
||||||
|
return [t]
|
||||||
|
|
||||||
|
flat_args = _flatten_tuple(args)
|
||||||
|
|
||||||
|
return "\n".join(
|
||||||
|
[
|
||||||
|
f"{idx}: {summarize_tensor(arg) if isinstance(arg, Tensor) else arg}"
|
||||||
|
for idx, arg in enumerate(iterable=flat_args)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def _filter_traceback(self, *frames: traceback.FrameSummary) -> list[traceback.FrameSummary]:
|
||||||
|
patterns_to_exclude = [
|
||||||
|
(r"torch/nn/modules/", r"^_call_impl$"),
|
||||||
|
(r"torch/nn/functional\.py", r""),
|
||||||
|
(r"refiners/fluxion/layers/", r"^_call_layer$"),
|
||||||
|
(r"refiners/fluxion/layers/", r"^forward$"),
|
||||||
|
(r"refiners/fluxion/layers/chain\.py", r""),
|
||||||
|
(r"", r"^_"),
|
||||||
|
]
|
||||||
|
|
||||||
|
def should_exclude(frame: traceback.FrameSummary, /) -> bool:
|
||||||
|
for filename_pattern, name_pattern in patterns_to_exclude:
|
||||||
|
if re.search(pattern=filename_pattern, string=frame.filename) and re.search(
|
||||||
|
pattern=name_pattern, string=frame.name
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
return [frame for frame in frames if not should_exclude(frame)]
|
||||||
|
|
||||||
|
def _call_layer(self, layer: Module, name: str, /, *args: Any) -> Any:
|
||||||
try:
|
try:
|
||||||
return layer(*args)
|
return layer(*args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pretty_print = self.debug_repr(layer_name)
|
exc_type, _, exc_traceback = sys.exc_info()
|
||||||
raise ValueError(f"Error in layer {layer_name}, args:\n {args}\n \n{pretty_print}") from e
|
assert exc_type
|
||||||
|
tb_list = traceback.extract_tb(tb=exc_traceback)
|
||||||
|
filtered_tb_list = self._filter_traceback(*tb_list)
|
||||||
|
formatted_tb = "".join(traceback.format_list(extracted_list=filtered_tb_list))
|
||||||
|
pretty_args = Chain._pretty_print_args(args)
|
||||||
|
error_tree = self._show_error_in_tree(name)
|
||||||
|
|
||||||
|
exception_str = re.sub(pattern=r"\n\s*\n", repl="\n", string=str(object=e))
|
||||||
|
message = f"{formatted_tb}\n{exception_str}\n---------------\n{error_tree}\n{pretty_args}"
|
||||||
|
if "Error" not in exception_str:
|
||||||
|
message = f"{exc_type.__name__}:\n {message}"
|
||||||
|
|
||||||
|
raise ChainError(message) from None
|
||||||
|
|
||||||
def forward(self, *args: Any) -> Any:
|
def forward(self, *args: Any) -> Any:
|
||||||
result: tuple[Any] | Any = None
|
result: tuple[Any] | Any = None
|
||||||
intermediate_args: tuple[Any, ...] = args
|
intermediate_args: tuple[Any, ...] = args
|
||||||
for name, layer in self._modules.items():
|
for name, layer in self._modules.items():
|
||||||
result = self.call_layer(layer, name, *intermediate_args)
|
result = self._call_layer(layer, name, *intermediate_args)
|
||||||
intermediate_args = (result,) if not isinstance(result, tuple) else result
|
intermediate_args = (result,) if not isinstance(result, tuple) else result
|
||||||
|
|
||||||
self._reset_context()
|
self._reset_context()
|
||||||
|
@ -409,7 +485,7 @@ class Parallel(Chain):
|
||||||
_tag = "PAR"
|
_tag = "PAR"
|
||||||
|
|
||||||
def forward(self, *args: Any) -> tuple[Tensor, ...]:
|
def forward(self, *args: Any) -> tuple[Tensor, ...]:
|
||||||
return tuple([self.call_layer(module, name, *args) for name, module in self._modules.items()])
|
return tuple([self._call_layer(module, name, *args) for name, module in self._modules.items()])
|
||||||
|
|
||||||
def _show_only_tag(self) -> bool:
|
def _show_only_tag(self) -> bool:
|
||||||
return self.__class__ == Parallel
|
return self.__class__ == Parallel
|
||||||
|
@ -421,7 +497,7 @@ class Distribute(Chain):
|
||||||
def forward(self, *args: Any) -> tuple[Tensor, ...]:
|
def forward(self, *args: Any) -> tuple[Tensor, ...]:
|
||||||
n, m = len(args), len(self._modules)
|
n, m = len(args), len(self._modules)
|
||||||
assert n == m, f"Number of positional arguments ({n}) must match number of sub-modules ({m})."
|
assert n == m, f"Number of positional arguments ({n}) must match number of sub-modules ({m})."
|
||||||
return tuple([self.call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())])
|
return tuple([self._call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())])
|
||||||
|
|
||||||
def _show_only_tag(self) -> bool:
|
def _show_only_tag(self) -> bool:
|
||||||
return self.__class__ == Distribute
|
return self.__class__ == Distribute
|
||||||
|
|
|
@ -59,7 +59,7 @@ class Module(TorchModule):
|
||||||
|
|
||||||
def pretty_print(self, depth: int = -1) -> None:
|
def pretty_print(self, depth: int = -1) -> None:
|
||||||
tree = ModuleTree(module=self)
|
tree = ModuleTree(module=self)
|
||||||
print(tree.generate_tree_repr(tree.root, is_root=True, depth=depth))
|
print(tree._generate_tree_repr(tree.root, is_root=True, depth=depth)) # type: ignore[reportPrivateUsage]
|
||||||
|
|
||||||
def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType]:
|
def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType]:
|
||||||
"""Return a dictionary of basic attributes of the module.
|
"""Return a dictionary of basic attributes of the module.
|
||||||
|
@ -182,10 +182,22 @@ class ModuleTree:
|
||||||
return f"{self.__class__.__name__}(root={self.root['value']})"
|
return f"{self.__class__.__name__}(root={self.root['value']})"
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return self.generate_tree_repr(node=self.root, is_root=True, depth=7)
|
return self._generate_tree_repr(self.root, is_root=True, depth=7)
|
||||||
|
|
||||||
def generate_tree_repr(
|
def __iter__(self) -> Generator[TreeNode, None, None]:
|
||||||
self, node: TreeNode, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1
|
for child in self.root["children"]:
|
||||||
|
yield child
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def shorten_tree_repr(cls, tree_repr: str, /, line_index: int = 0, max_lines: int = 20) -> str:
|
||||||
|
"""Shorten the tree representation to a given number of lines around a given line index."""
|
||||||
|
lines = tree_repr.split(sep="\n")
|
||||||
|
start_idx = max(0, line_index - max_lines // 2)
|
||||||
|
end_idx = min(len(lines), line_index + max_lines // 2 + 1)
|
||||||
|
return "\n".join(lines[start_idx:end_idx])
|
||||||
|
|
||||||
|
def _generate_tree_repr(
|
||||||
|
self, node: TreeNode, /, *, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1
|
||||||
) -> str:
|
) -> str:
|
||||||
if depth == 0 and node["children"]:
|
if depth == 0 and node["children"]:
|
||||||
return f"{prefix}{'└── ' if is_last else '├── '}{node['value']} ..."
|
return f"{prefix}{'└── ' if is_last else '├── '}{node['value']} ..."
|
||||||
|
@ -211,7 +223,7 @@ class ModuleTree:
|
||||||
else:
|
else:
|
||||||
child_value = child["value"]
|
child_value = child["value"]
|
||||||
|
|
||||||
child_str = self.generate_tree_repr(
|
child_str = self._generate_tree_repr(
|
||||||
{"value": child_value, "class_name": child["class_name"], "children": child["children"]},
|
{"value": child_value, "class_name": child["class_name"], "children": child["children"]},
|
||||||
prefix=prefix + new_prefix,
|
prefix=prefix + new_prefix,
|
||||||
is_last=i == len(node["children"]) - 1,
|
is_last=i == len(node["children"]) - 1,
|
||||||
|
|
|
@ -137,3 +137,23 @@ def load_metadata_from_safetensors(path: Path | str) -> dict[str, str] | None:
|
||||||
|
|
||||||
def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None:
|
def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None:
|
||||||
_save_file(tensors, path, metadata) # type: ignore
|
_save_file(tensors, path, metadata) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def summarize_tensor(tensor: torch.Tensor, /) -> str:
    """Return a one-line, human-readable summary of a tensor.

    The summary always lists the shape, dtype (without the `torch.` prefix),
    device and `requires_grad` flag. For non-empty tensors it also lists the
    min, max, mean, standard deviation and norm, each formatted with two
    decimals.

    This helper is meant for debug output, so it must not raise on unusual
    inputs:
      - empty tensors have no values to reduce, so the statistics are omitted;
      - integer and boolean tensors are converted to float first, because
        `mean` and `std` are only defined for floating-point dtypes;
      - complex tensors are summarized through their magnitudes (`abs`).
    """
    info = [
        f"shape=({', '.join(map(str, tensor.shape))})",
        f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",
        f"device={tensor.device}",
    ]

    if tensor.numel() > 0:
        if tensor.is_complex():
            values = tensor.abs()
        elif tensor.is_floating_point():
            # Keep the original dtype so floating-point tensors are reported
            # exactly as they were before.
            values = tensor
        else:
            values = tensor.float()

        info += [
            f"min={values.min():.2f}",  # type: ignore
            f"max={values.max():.2f}",  # type: ignore
            f"mean={values.mean():.2f}",
            f"std={values.std():.2f}",
            f"norm={norm(x=values):.2f}",
        ]

    info.append(f"grad={tensor.requires_grad}")

    return "Tensor(" + ", ".join(info) + ")"
|
||||||
|
|
|
@ -226,3 +226,19 @@ def test_setattr_dont_register() -> None:
|
||||||
chain.foo = fl.Linear(in_features=1, out_features=1)
|
chain.foo = fl.Linear(in_features=1, out_features=1)
|
||||||
|
|
||||||
assert module_keys(chain=chain) == ["Linear_1", "Linear_2"]
|
assert module_keys(chain=chain) == ["Linear_1", "Linear_2"]
|
||||||
|
|
||||||
|
|
||||||
|
# Expected rendering of the module tree for the chain built in `test_debug_print`.
# NOTE(review): the whitespace inside this literal may have been collapsed by the
# diff viewer this text was recovered from — confirm against the real tree output.
EXPECTED_TREE = (
    "(CHAIN)\n ├── Linear(in_features=1, out_features=1) (x2)\n └── (CHAIN)\n ├── Linear(in_features=1,"
    " out_features=1) #1\n └── Linear(in_features=2, out_features=1) #2"
)


def test_debug_print() -> None:
    """Check the tree rendered by `Chain._show_error_in_tree` for a nested chain."""
    chain = fl.Chain(
        fl.Linear(1, 1),
        fl.Linear(1, 1),
        fl.Chain(fl.Linear(1, 1), fl.Linear(2, 1)),
    )

    assert chain._show_error_in_tree("Chain.Linear_2") == EXPECTED_TREE  # type: ignore[reportPrivateUsage]
|
||||||
|
|
Loading…
Reference in a new issue