mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-14 09:08:14 +00:00
(doc/fluxion/module) add/convert docstrings to mkdocstrings format
This commit is contained in:
parent
c31da03bad
commit
e3238a6af5
|
@ -20,6 +20,8 @@ BasicType = str | float | int | bool
|
||||||
|
|
||||||
|
|
||||||
class Module(TorchModule):
|
class Module(TorchModule):
|
||||||
|
"""A wrapper around [`torch.nn.Module`][torch.nn.Module]."""
|
||||||
|
|
||||||
_parameters: dict[str, Any]
|
_parameters: dict[str, Any]
|
||||||
_buffers: dict[str, Any]
|
_buffers: dict[str, Any]
|
||||||
_tag: str = ""
|
_tag: str = ""
|
||||||
|
@ -34,14 +36,38 @@ class Module(TorchModule):
|
||||||
return super().__setattr__(name=name, value=value)
|
return super().__setattr__(name=name, value=value)
|
||||||
|
|
||||||
def load_from_safetensors(self, tensors_path: str | Path, strict: bool = True) -> "Module":
|
def load_from_safetensors(self, tensors_path: str | Path, strict: bool = True) -> "Module":
|
||||||
|
"""Load the module's state from a SafeTensors file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensors_path: The path to the SafeTensors file.
|
||||||
|
strict: Whether to raise an error if the SafeTensors's
|
||||||
|
content doesn't map perfectly to the module's state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The module, with its state loaded from the SafeTensors file.
|
||||||
|
"""
|
||||||
state_dict = load_from_safetensors(tensors_path)
|
state_dict = load_from_safetensors(tensors_path)
|
||||||
self.load_state_dict(state_dict, strict=strict)
|
self.load_state_dict(state_dict, strict=strict)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def named_modules(self, *args: Any, **kwargs: Any) -> "Generator[tuple[str, Module], None, None]": # type: ignore
|
def named_modules(self, *args: Any, **kwargs: Any) -> "Generator[tuple[str, Module], None, None]": # type: ignore
|
||||||
|
"""Get all the sub-modules of the module.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An iterator over all the sub-modules of the module.
|
||||||
|
"""
|
||||||
return super().named_modules(*args) # type: ignore
|
return super().named_modules(*args) # type: ignore
|
||||||
|
|
||||||
def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore
|
def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore
|
||||||
|
"""Move the module to the given device and cast its parameters to the given dtype.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device: The device to move the module to.
|
||||||
|
dtype: The dtype to cast the module's parameters to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The module, moved to the given device and cast to the given dtype.
|
||||||
|
"""
|
||||||
return super().to(device=device, dtype=dtype) # type: ignore
|
return super().to(device=device, dtype=dtype) # type: ignore
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
|
@ -56,13 +82,20 @@ class Module(TorchModule):
|
||||||
return repr(tree)
|
return repr(tree)
|
||||||
|
|
||||||
def pretty_print(self, depth: int = -1) -> None:
|
def pretty_print(self, depth: int = -1) -> None:
|
||||||
|
"""Print the module in a tree-like format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
depth: The maximum depth of the tree to print.
|
||||||
|
If negative, the whole tree is printed.
|
||||||
|
"""
|
||||||
tree = ModuleTree(module=self)
|
tree = ModuleTree(module=self)
|
||||||
print(tree._generate_tree_repr(tree.root, is_root=True, depth=depth)) # type: ignore[reportPrivateUsage]
|
print(tree._generate_tree_repr(tree.root, is_root=True, depth=depth)) # type: ignore[reportPrivateUsage]
|
||||||
|
|
||||||
def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType | Sequence[BasicType]]:
|
def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType | Sequence[BasicType]]:
|
||||||
"""Return a dictionary of basic attributes of the module.
|
"""Return a dictionary of basic attributes of the module.
|
||||||
|
|
||||||
Basic attributes are public attributes made of basic types (int, float, str, bool) or a sequence of basic types.
|
Args:
|
||||||
|
init_attrs_only: Whether to only return attributes that are passed to the module's constructor.
|
||||||
"""
|
"""
|
||||||
sig = signature(obj=self.__init__)
|
sig = signature(obj=self.__init__)
|
||||||
init_params = set(sig.parameters.keys()) - {"self"}
|
init_params = set(sig.parameters.keys()) - {"self"}
|
||||||
|
@ -95,12 +128,12 @@ class Module(TorchModule):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_path(self, parent: "Chain | None" = None, top: "Module | None" = None) -> str:
|
def get_path(self, parent: "Chain | None" = None, top: "Module | None" = None) -> str:
|
||||||
"""Helper for debugging purpose only.
|
"""Get the path of the module in the chain.
|
||||||
|
|
||||||
Returns the path of the module in the chain as a string.
|
Args:
|
||||||
|
parent: The parent of the module in the chain.
|
||||||
If `top` is set then the path will be relative to `top`,
|
top: The top module of the chain.
|
||||||
otherwise it will be relative to the root of the chain.
|
If None, the path will be relative to the root of the chain.
|
||||||
"""
|
"""
|
||||||
if (parent is None) or (self == top):
|
if (parent is None) or (self == top):
|
||||||
return self.__class__.__name__
|
return self.__class__.__name__
|
||||||
|
@ -111,6 +144,8 @@ class Module(TorchModule):
|
||||||
|
|
||||||
|
|
||||||
class ContextModule(Module):
|
class ContextModule(Module):
|
||||||
|
"""A module containing a [`ContextProvider`][refiners.fluxion.context.ContextProvider]."""
|
||||||
|
|
||||||
# we store parent into a one element list to avoid pytorch thinking it's a submodule
|
# we store parent into a one element list to avoid pytorch thinking it's a submodule
|
||||||
_parent: "list[Chain]"
|
_parent: "list[Chain]"
|
||||||
_can_refresh_parent: bool = True # see usage in Adapter and Chain
|
_can_refresh_parent: bool = True # see usage in Adapter and Chain
|
||||||
|
@ -121,14 +156,21 @@ class ContextModule(Module):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parent(self) -> "Chain | None":
|
def parent(self) -> "Chain | None":
|
||||||
|
"""Return the module's parent, or None if module is an orphan."""
|
||||||
return self._parent[0] if self._parent else None
|
return self._parent[0] if self._parent else None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ensure_parent(self) -> "Chain":
|
def ensure_parent(self) -> "Chain":
|
||||||
assert self._parent, "module is not bound to a Chain"
|
"""Return the module's parent, or raise an error if module is an orphan."""
|
||||||
return self._parent[0]
|
assert self.parent, "module does not have a parent"
|
||||||
|
return self.parent
|
||||||
|
|
||||||
|
def get_parents(self) -> "list[Chain]":
|
||||||
|
"""Recursively retrieve the module's parents."""
|
||||||
|
return self._parent + self._parent[0].get_parents() if self._parent else []
|
||||||
|
|
||||||
def _set_parent(self, parent: "Chain | None") -> None:
|
def _set_parent(self, parent: "Chain | None") -> None:
|
||||||
|
"""Set the parent of the module."""
|
||||||
if not self._can_refresh_parent:
|
if not self._can_refresh_parent:
|
||||||
return
|
return
|
||||||
if parent is None:
|
if parent is None:
|
||||||
|
@ -140,11 +182,9 @@ class ContextModule(Module):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def provider(self) -> ContextProvider:
|
def provider(self) -> ContextProvider:
|
||||||
|
"""Return the module's context provider."""
|
||||||
return self.ensure_parent.provider
|
return self.ensure_parent.provider
|
||||||
|
|
||||||
def get_parents(self) -> "list[Chain]":
|
|
||||||
return self._parent + self._parent[0].get_parents() if self._parent else []
|
|
||||||
|
|
||||||
def use_context(self, context_name: str) -> Context:
|
def use_context(self, context_name: str) -> Context:
|
||||||
"""Retrieve the context object from the module's context provider."""
|
"""Retrieve the context object from the module's context provider."""
|
||||||
context = self.provider.get_context(context_name)
|
context = self.provider.get_context(context_name)
|
||||||
|
@ -170,20 +210,36 @@ class ContextModule(Module):
|
||||||
return clone
|
return clone
|
||||||
|
|
||||||
def get_path(self, parent: "Chain | None" = None, top: "Module | None" = None) -> str:
|
def get_path(self, parent: "Chain | None" = None, top: "Module | None" = None) -> str:
|
||||||
|
"""Get the path of the module in the chain.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parent: The parent of the module in the chain.
|
||||||
|
top: The top module of the chain.
|
||||||
|
If None, the path will be relative to the root of the chain.
|
||||||
|
"""
|
||||||
|
|
||||||
return super().get_path(parent=parent or self.parent, top=top)
|
return super().get_path(parent=parent or self.parent, top=top)
|
||||||
|
|
||||||
|
|
||||||
class WeightedModule(Module):
|
class WeightedModule(Module):
|
||||||
|
"""A module with a weight (Tensor) attribute."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self) -> Device:
|
def device(self) -> Device:
|
||||||
|
"""Return the device of the module's weight."""
|
||||||
return self.weight.device
|
return self.weight.device
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self) -> DType:
|
def dtype(self) -> DType:
|
||||||
|
"""Return the dtype of the module's weight."""
|
||||||
return self.weight.dtype
|
return self.weight.dtype
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f"{super().__str__().removesuffix(')')}, device={self.device}, dtype={str(self.dtype).removeprefix('torch.')})"
|
return (
|
||||||
|
f"{super().__str__().removesuffix(')')}, "
|
||||||
|
f"device={self.device}, "
|
||||||
|
f"dtype={str(self.dtype).removeprefix('torch.')})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TreeNode(TypedDict):
|
class TreeNode(TypedDict):
|
||||||
|
@ -193,6 +249,11 @@ class TreeNode(TypedDict):
|
||||||
|
|
||||||
|
|
||||||
class ModuleTree:
|
class ModuleTree:
|
||||||
|
"""A Tree of Modules.
|
||||||
|
|
||||||
|
This is useful to visualize the relations between modules.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, module: Module) -> None:
|
def __init__(self, module: Module) -> None:
|
||||||
self.root: TreeNode = self._module_to_tree(module=module)
|
self.root: TreeNode = self._module_to_tree(module=module)
|
||||||
self._fold_successive_identical(node=self.root)
|
self._fold_successive_identical(node=self.root)
|
||||||
|
@ -208,7 +269,13 @@ class ModuleTree:
|
||||||
yield child
|
yield child
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def shorten_tree_repr(cls, tree_repr: str, /, line_index: int = 0, max_lines: int = 20) -> str:
|
def shorten_tree_repr(
|
||||||
|
cls,
|
||||||
|
tree_repr: str,
|
||||||
|
/,
|
||||||
|
line_index: int = 0,
|
||||||
|
max_lines: int = 20,
|
||||||
|
) -> str:
|
||||||
"""Shorten the tree representation to a given number of lines around a given line index."""
|
"""Shorten the tree representation to a given number of lines around a given line index."""
|
||||||
lines = tree_repr.split(sep="\n")
|
lines = tree_repr.split(sep="\n")
|
||||||
start_idx = max(0, line_index - max_lines // 2)
|
start_idx = max(0, line_index - max_lines // 2)
|
||||||
|
@ -216,7 +283,14 @@ class ModuleTree:
|
||||||
return "\n".join(lines[start_idx:end_idx])
|
return "\n".join(lines[start_idx:end_idx])
|
||||||
|
|
||||||
def _generate_tree_repr(
|
def _generate_tree_repr(
|
||||||
self, node: TreeNode, /, *, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1
|
self,
|
||||||
|
node: TreeNode,
|
||||||
|
/,
|
||||||
|
*,
|
||||||
|
prefix: str = "",
|
||||||
|
is_last: bool = True,
|
||||||
|
is_root: bool = True,
|
||||||
|
depth: int = -1,
|
||||||
) -> str:
|
) -> str:
|
||||||
if depth == 0 and node["children"]:
|
if depth == 0 and node["children"]:
|
||||||
return f"{prefix}{'└── ' if is_last else '├── '}{node['value']} ..."
|
return f"{prefix}{'└── ' if is_last else '├── '}{node['value']} ..."
|
||||||
|
|
Loading…
Reference in a new issue