mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 07:08:45 +00:00
(doc/fluxion/module) add/convert docstrings to mkdocstrings format
This commit is contained in:
parent
c31da03bad
commit
e3238a6af5
|
@ -20,6 +20,8 @@ BasicType = str | float | int | bool
|
|||
|
||||
|
||||
class Module(TorchModule):
|
||||
"""A wrapper around [`torch.nn.Module`][torch.nn.Module]."""
|
||||
|
||||
_parameters: dict[str, Any]
|
||||
_buffers: dict[str, Any]
|
||||
_tag: str = ""
|
||||
|
@ -34,14 +36,38 @@ class Module(TorchModule):
|
|||
return super().__setattr__(name=name, value=value)
|
||||
|
||||
def load_from_safetensors(self, tensors_path: str | Path, strict: bool = True) -> "Module":
|
||||
"""Load the module's state from a SafeTensors file.
|
||||
|
||||
Args:
|
||||
tensors_path: The path to the SafeTensors file.
|
||||
strict: Whether to raise an error if the SafeTensors's
|
||||
content doesn't map perfectly to the module's state.
|
||||
|
||||
Returns:
|
||||
The module, with its state loaded from the SafeTensors file.
|
||||
"""
|
||||
state_dict = load_from_safetensors(tensors_path)
|
||||
self.load_state_dict(state_dict, strict=strict)
|
||||
return self
|
||||
|
||||
def named_modules(self, *args: Any, **kwargs: Any) -> "Generator[tuple[str, Module], None, None]": # type: ignore
|
||||
"""Get all the sub-modules of the module.
|
||||
|
||||
Returns:
|
||||
An iterator over all the sub-modules of the module.
|
||||
"""
|
||||
return super().named_modules(*args) # type: ignore
|
||||
|
||||
def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore
|
||||
"""Move the module to the given device and cast its parameters to the given dtype.
|
||||
|
||||
Args:
|
||||
device: The device to move the module to.
|
||||
dtype: The dtype to cast the module's parameters to.
|
||||
|
||||
Returns:
|
||||
The module, moved to the given device and cast to the given dtype.
|
||||
"""
|
||||
return super().to(device=device, dtype=dtype) # type: ignore
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
@ -56,13 +82,20 @@ class Module(TorchModule):
|
|||
return repr(tree)
|
||||
|
||||
def pretty_print(self, depth: int = -1) -> None:
|
||||
"""Print the module in a tree-like format.
|
||||
|
||||
Args:
|
||||
depth: The maximum depth of the tree to print.
|
||||
If negative, the whole tree is printed.
|
||||
"""
|
||||
tree = ModuleTree(module=self)
|
||||
print(tree._generate_tree_repr(tree.root, is_root=True, depth=depth)) # type: ignore[reportPrivateUsage]
|
||||
|
||||
def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType | Sequence[BasicType]]:
|
||||
"""Return a dictionary of basic attributes of the module.
|
||||
|
||||
Basic attributes are public attributes made of basic types (int, float, str, bool) or a sequence of basic types.
|
||||
Args:
|
||||
init_attrs_only: Whether to only return attributes that are passed to the module's constructor.
|
||||
"""
|
||||
sig = signature(obj=self.__init__)
|
||||
init_params = set(sig.parameters.keys()) - {"self"}
|
||||
|
@ -95,12 +128,12 @@ class Module(TorchModule):
|
|||
return False
|
||||
|
||||
def get_path(self, parent: "Chain | None" = None, top: "Module | None" = None) -> str:
|
||||
"""Helper for debugging purpose only.
|
||||
"""Get the path of the module in the chain.
|
||||
|
||||
Returns the path of the module in the chain as a string.
|
||||
|
||||
If `top` is set then the path will be relative to `top`,
|
||||
otherwise it will be relative to the root of the chain.
|
||||
Args:
|
||||
parent: The parent of the module in the chain.
|
||||
top: The top module of the chain.
|
||||
If None, the path will be relative to the root of the chain.
|
||||
"""
|
||||
if (parent is None) or (self == top):
|
||||
return self.__class__.__name__
|
||||
|
@ -111,6 +144,8 @@ class Module(TorchModule):
|
|||
|
||||
|
||||
class ContextModule(Module):
|
||||
"""A module containing a [`ContextProvider`][refiners.fluxion.context.ContextProvider]."""
|
||||
|
||||
# we store parent into a one element list to avoid pytorch thinking it's a submodule
|
||||
_parent: "list[Chain]"
|
||||
_can_refresh_parent: bool = True # see usage in Adapter and Chain
|
||||
|
@ -121,14 +156,21 @@ class ContextModule(Module):
|
|||
|
||||
@property
|
||||
def parent(self) -> "Chain | None":
|
||||
"""Return the module's parent, or None if module is an orphan."""
|
||||
return self._parent[0] if self._parent else None
|
||||
|
||||
@property
|
||||
def ensure_parent(self) -> "Chain":
|
||||
assert self._parent, "module is not bound to a Chain"
|
||||
return self._parent[0]
|
||||
"""Return the module's parent, or raise an error if module is an orphan."""
|
||||
assert self.parent, "module does not have a parent"
|
||||
return self.parent
|
||||
|
||||
def get_parents(self) -> "list[Chain]":
|
||||
"""Recursively retrieve the module's parents."""
|
||||
return self._parent + self._parent[0].get_parents() if self._parent else []
|
||||
|
||||
def _set_parent(self, parent: "Chain | None") -> None:
|
||||
"""Set the parent of the module."""
|
||||
if not self._can_refresh_parent:
|
||||
return
|
||||
if parent is None:
|
||||
|
@ -140,11 +182,9 @@ class ContextModule(Module):
|
|||
|
||||
@property
|
||||
def provider(self) -> ContextProvider:
|
||||
"""Return the module's context provider."""
|
||||
return self.ensure_parent.provider
|
||||
|
||||
def get_parents(self) -> "list[Chain]":
|
||||
return self._parent + self._parent[0].get_parents() if self._parent else []
|
||||
|
||||
def use_context(self, context_name: str) -> Context:
|
||||
"""Retrieve the context object from the module's context provider."""
|
||||
context = self.provider.get_context(context_name)
|
||||
|
@ -170,20 +210,36 @@ class ContextModule(Module):
|
|||
return clone
|
||||
|
||||
def get_path(self, parent: "Chain | None" = None, top: "Module | None" = None) -> str:
|
||||
"""Get the path of the module in the chain.
|
||||
|
||||
Args:
|
||||
parent: The parent of the module in the chain.
|
||||
top: The top module of the chain.
|
||||
If None, the path will be relative to the root of the chain.
|
||||
"""
|
||||
|
||||
return super().get_path(parent=parent or self.parent, top=top)
|
||||
|
||||
|
||||
class WeightedModule(Module):
|
||||
"""A module with a weight (Tensor) attribute."""
|
||||
|
||||
@property
|
||||
def device(self) -> Device:
|
||||
"""Return the device of the module's weight."""
|
||||
return self.weight.device
|
||||
|
||||
@property
|
||||
def dtype(self) -> DType:
|
||||
"""Return the dtype of the module's weight."""
|
||||
return self.weight.dtype
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{super().__str__().removesuffix(')')}, device={self.device}, dtype={str(self.dtype).removeprefix('torch.')})"
|
||||
return (
|
||||
f"{super().__str__().removesuffix(')')}, "
|
||||
f"device={self.device}, "
|
||||
f"dtype={str(self.dtype).removeprefix('torch.')})"
|
||||
)
|
||||
|
||||
|
||||
class TreeNode(TypedDict):
|
||||
|
@ -193,6 +249,11 @@ class TreeNode(TypedDict):
|
|||
|
||||
|
||||
class ModuleTree:
|
||||
"""A Tree of Modules.
|
||||
|
||||
This is useful to visualize the relations between modules.
|
||||
"""
|
||||
|
||||
def __init__(self, module: Module) -> None:
|
||||
self.root: TreeNode = self._module_to_tree(module=module)
|
||||
self._fold_successive_identical(node=self.root)
|
||||
|
@ -208,7 +269,13 @@ class ModuleTree:
|
|||
yield child
|
||||
|
||||
@classmethod
|
||||
def shorten_tree_repr(cls, tree_repr: str, /, line_index: int = 0, max_lines: int = 20) -> str:
|
||||
def shorten_tree_repr(
|
||||
cls,
|
||||
tree_repr: str,
|
||||
/,
|
||||
line_index: int = 0,
|
||||
max_lines: int = 20,
|
||||
) -> str:
|
||||
"""Shorten the tree representation to a given number of lines around a given line index."""
|
||||
lines = tree_repr.split(sep="\n")
|
||||
start_idx = max(0, line_index - max_lines // 2)
|
||||
|
@ -216,7 +283,14 @@ class ModuleTree:
|
|||
return "\n".join(lines[start_idx:end_idx])
|
||||
|
||||
def _generate_tree_repr(
|
||||
self, node: TreeNode, /, *, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1
|
||||
self,
|
||||
node: TreeNode,
|
||||
/,
|
||||
*,
|
||||
prefix: str = "",
|
||||
is_last: bool = True,
|
||||
is_root: bool = True,
|
||||
depth: int = -1,
|
||||
) -> str:
|
||||
if depth == 0 and node["children"]:
|
||||
return f"{prefix}{'└── ' if is_last else '├── '}{node['value']} ..."
|
||||
|
|
Loading…
Reference in a new issue