(doc/fluxion/module) add/convert docstrings to mkdocstrings format

This commit is contained in:
Laurent 2024-02-01 22:03:46 +00:00 committed by Laureηt
parent c31da03bad
commit e3238a6af5

View file

@ -20,6 +20,8 @@ BasicType = str | float | int | bool
class Module(TorchModule):
"""A wrapper around [`torch.nn.Module`][torch.nn.Module]."""
_parameters: dict[str, Any]
_buffers: dict[str, Any]
_tag: str = ""
@ -34,14 +36,38 @@ class Module(TorchModule):
return super().__setattr__(name=name, value=value)
def load_from_safetensors(self, tensors_path: str | Path, strict: bool = True) -> "Module":
    """Populate the module's state from a SafeTensors file.

    Args:
        tensors_path: Location of the SafeTensors file to read.
        strict: Forwarded to `load_state_dict`; whether to raise an error when
            the file's content doesn't map perfectly to the module's state.

    Returns:
        This same module, once its state has been loaded from the file.
    """
    # Inside a method body the bare name resolves to the global
    # `load_from_safetensors` function, not to this method.
    self.load_state_dict(load_from_safetensors(tensors_path), strict=strict)
    return self
def named_modules(self, *args: Any, **kwargs: Any) -> "Generator[tuple[str, Module], None, None]":  # type: ignore
    """Get all the sub-modules of the module.

    Args:
        *args: Positional arguments forwarded unchanged to the parent
            class's `named_modules`.
        **kwargs: Keyword arguments forwarded unchanged to the parent
            class's `named_modules`.

    Returns:
        An iterator over all the sub-modules of the module, as
        `(name, module)` pairs.
    """
    # Keyword arguments must be forwarded as well: the signature accepts them,
    # so dropping them would silently ignore what the caller asked for.
    return super().named_modules(*args, **kwargs)  # type: ignore
def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T:  # type: ignore
    """Move the module to a device and/or cast its parameters to a dtype.

    Both arguments are passed by keyword to the parent class's `to`.

    Args:
        device: The device to move the module to.
        dtype: The dtype to cast the module's parameters to.

    Returns:
        Whatever the parent class's `to` returns, i.e. the module moved to
        the given device and cast to the given dtype.
    """
    converted = super().to(device=device, dtype=dtype)  # type: ignore
    return converted
def __str__(self) -> str:
@ -56,13 +82,20 @@ class Module(TorchModule):
return repr(tree)
def pretty_print(self, depth: int = -1) -> None:
    """Print the module as a tree.

    Args:
        depth: The maximum depth of the tree to print.
            If negative, the whole tree is printed.
    """
    module_tree = ModuleTree(module=self)
    rendered = module_tree._generate_tree_repr(module_tree.root, is_root=True, depth=depth)  # type: ignore[reportPrivateUsage]
    print(rendered)
def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType | Sequence[BasicType]]:
"""Return a dictionary of basic attributes of the module.
Basic attributes are public attributes made of basic types (int, float, str, bool) or a sequence of basic types.
Args:
init_attrs_only: Whether to only return attributes that are passed to the module's constructor.
"""
sig = signature(obj=self.__init__)
init_params = set(sig.parameters.keys()) - {"self"}
@ -95,12 +128,12 @@ class Module(TorchModule):
return False
def get_path(self, parent: "Chain | None" = None, top: "Module | None" = None) -> str:
"""Helper for debugging purpose only.
"""Get the path of the module in the chain.
Returns the path of the module in the chain as a string.
If `top` is set then the path will be relative to `top`,
otherwise it will be relative to the root of the chain.
Args:
parent: The parent of the module in the chain.
top: The top module of the chain.
If None, the path will be relative to the root of the chain.
"""
if (parent is None) or (self == top):
return self.__class__.__name__
@ -111,6 +144,8 @@ class Module(TorchModule):
class ContextModule(Module):
"""A module containing a [`ContextProvider`][refiners.fluxion.context.ContextProvider]."""
# we store parent into a one element list to avoid pytorch thinking it's a submodule
_parent: "list[Chain]"
_can_refresh_parent: bool = True # see usage in Adapter and Chain
@ -121,14 +156,21 @@ class ContextModule(Module):
@property
def parent(self) -> "Chain | None":
    """Return the module's parent, or None if the module is an orphan."""
    # `_parent` is a list holding at most the parent; empty means no parent.
    if not self._parent:
        return None
    return self._parent[0]
@property
def ensure_parent(self) -> "Chain":
assert self._parent, "module is not bound to a Chain"
return self._parent[0]
"""Return the module's parent, or raise an error if module is an orphan."""
assert self.parent, "module does not have a parent"
return self.parent
def get_parents(self) -> "list[Chain]":
    """Recursively collect the module's parents, closest parent first."""
    if not self._parent:
        return []
    direct_parent = self._parent[0]
    # Concatenation builds a new list, so `_parent` itself is never mutated.
    return self._parent + direct_parent.get_parents()
def _set_parent(self, parent: "Chain | None") -> None:
"""Set the parent of the module."""
if not self._can_refresh_parent:
return
if parent is None:
@ -140,11 +182,9 @@ class ContextModule(Module):
@property
def provider(self) -> ContextProvider:
    """Return the context provider of the module's parent.

    Goes through `ensure_parent`, so a parent is required.
    """
    chain = self.ensure_parent
    return chain.provider
def get_parents(self) -> "list[Chain]":
return self._parent + self._parent[0].get_parents() if self._parent else []
def use_context(self, context_name: str) -> Context:
"""Retrieve the context object from the module's context provider."""
context = self.provider.get_context(context_name)
@ -170,20 +210,36 @@ class ContextModule(Module):
return clone
def get_path(self, parent: "Chain | None" = None, top: "Module | None" = None) -> str:
    """Get the path of the module in the chain.

    Args:
        parent: The parent of the module in the chain.
            When falsy, the module's own `parent` is used instead.
        top: The top module of the chain.
            If None, the path will be relative to the root of the chain.
    """
    effective_parent = parent or self.parent
    return super().get_path(parent=effective_parent, top=top)
class WeightedModule(Module):
    """A module with a weight (Tensor) attribute."""

    @property
    def device(self) -> Device:
        """Return the device of the module's weight."""
        weight = self.weight
        return weight.device

    @property
    def dtype(self) -> DType:
        """Return the dtype of the module's weight."""
        weight = self.weight
        return weight.dtype

    def __str__(self) -> str:
        """Extend the parent's string form with the weight's device and dtype."""
        # Drop the parent's closing parenthesis so the extra fields can be
        # appended inside the same pair of parentheses.
        base = super().__str__().removesuffix(")")
        device = self.device
        dtype_name = str(self.dtype).removeprefix("torch.")
        return f"{base}, device={device}, dtype={dtype_name})"
class TreeNode(TypedDict):
@ -193,6 +249,11 @@ class TreeNode(TypedDict):
class ModuleTree:
"""A Tree of Modules.
This is useful to visualize the relations between modules.
"""
def __init__(self, module: Module) -> None:
self.root: TreeNode = self._module_to_tree(module=module)
self._fold_successive_identical(node=self.root)
@ -208,7 +269,13 @@ class ModuleTree:
yield child
@classmethod
def shorten_tree_repr(cls, tree_repr: str, /, line_index: int = 0, max_lines: int = 20) -> str:
def shorten_tree_repr(
cls,
tree_repr: str,
/,
line_index: int = 0,
max_lines: int = 20,
) -> str:
"""Shorten the tree representation to a given number of lines around a given line index."""
lines = tree_repr.split(sep="\n")
start_idx = max(0, line_index - max_lines // 2)
@ -216,7 +283,14 @@ class ModuleTree:
return "\n".join(lines[start_idx:end_idx])
def _generate_tree_repr(
self, node: TreeNode, /, *, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1
self,
node: TreeNode,
/,
*,
prefix: str = "",
is_last: bool = True,
is_root: bool = True,
depth: int = -1,
) -> str:
if depth == 0 and node["children"]:
return f"{prefix}{'└── ' if is_last else '├── '}{node['value']} ..."