(doc/fluxion/module) add/convert docstrings to mkdocstrings format

This commit is contained in:
Laurent 2024-02-01 22:03:46 +00:00 committed by Laureηt
parent c31da03bad
commit e3238a6af5

View file

@ -20,6 +20,8 @@ BasicType = str | float | int | bool
class Module(TorchModule): class Module(TorchModule):
"""A wrapper around [`torch.nn.Module`][torch.nn.Module]."""
_parameters: dict[str, Any] _parameters: dict[str, Any]
_buffers: dict[str, Any] _buffers: dict[str, Any]
_tag: str = "" _tag: str = ""
@ -34,14 +36,38 @@ class Module(TorchModule):
return super().__setattr__(name=name, value=value) return super().__setattr__(name=name, value=value)
def load_from_safetensors(self, tensors_path: str | Path, strict: bool = True) -> "Module": def load_from_safetensors(self, tensors_path: str | Path, strict: bool = True) -> "Module":
"""Load the module's state from a SafeTensors file.
Args:
tensors_path: The path to the SafeTensors file.
strict: Whether to raise an error if the SafeTensors's
content doesn't map perfectly to the module's state.
Returns:
The module, with its state loaded from the SafeTensors file.
"""
state_dict = load_from_safetensors(tensors_path) state_dict = load_from_safetensors(tensors_path)
self.load_state_dict(state_dict, strict=strict) self.load_state_dict(state_dict, strict=strict)
return self return self
def named_modules(self, *args: Any, **kwargs: Any) -> "Generator[tuple[str, Module], None, None]": # type: ignore def named_modules(self, *args: Any, **kwargs: Any) -> "Generator[tuple[str, Module], None, None]": # type: ignore
"""Get all the sub-modules of the module.
Returns:
An iterator over all the sub-modules of the module.
"""
return super().named_modules(*args) # type: ignore return super().named_modules(*args) # type: ignore
def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore
"""Move the module to the given device and cast its parameters to the given dtype.
Args:
device: The device to move the module to.
dtype: The dtype to cast the module's parameters to.
Returns:
The module, moved to the given device and cast to the given dtype.
"""
return super().to(device=device, dtype=dtype) # type: ignore return super().to(device=device, dtype=dtype) # type: ignore
def __str__(self) -> str: def __str__(self) -> str:
@ -56,13 +82,20 @@ class Module(TorchModule):
return repr(tree) return repr(tree)
def pretty_print(self, depth: int = -1) -> None: def pretty_print(self, depth: int = -1) -> None:
"""Print the module in a tree-like format.
Args:
depth: The maximum depth of the tree to print.
If negative, the whole tree is printed.
"""
tree = ModuleTree(module=self) tree = ModuleTree(module=self)
print(tree._generate_tree_repr(tree.root, is_root=True, depth=depth)) # type: ignore[reportPrivateUsage] print(tree._generate_tree_repr(tree.root, is_root=True, depth=depth)) # type: ignore[reportPrivateUsage]
def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType | Sequence[BasicType]]: def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType | Sequence[BasicType]]:
"""Return a dictionary of basic attributes of the module. """Return a dictionary of basic attributes of the module.
Basic attributes are public attributes made of basic types (int, float, str, bool) or a sequence of basic types. Args:
init_attrs_only: Whether to only return attributes that are passed to the module's constructor.
""" """
sig = signature(obj=self.__init__) sig = signature(obj=self.__init__)
init_params = set(sig.parameters.keys()) - {"self"} init_params = set(sig.parameters.keys()) - {"self"}
@ -95,12 +128,12 @@ class Module(TorchModule):
return False return False
def get_path(self, parent: "Chain | None" = None, top: "Module | None" = None) -> str: def get_path(self, parent: "Chain | None" = None, top: "Module | None" = None) -> str:
"""Helper for debugging purpose only. """Get the path of the module in the chain.
Returns the path of the module in the chain as a string. Args:
parent: The parent of the module in the chain.
If `top` is set then the path will be relative to `top`, top: The top module of the chain.
otherwise it will be relative to the root of the chain. If None, the path will be relative to the root of the chain.
""" """
if (parent is None) or (self == top): if (parent is None) or (self == top):
return self.__class__.__name__ return self.__class__.__name__
@ -111,6 +144,8 @@ class Module(TorchModule):
class ContextModule(Module): class ContextModule(Module):
"""A module containing a [`ContextProvider`][refiners.fluxion.context.ContextProvider]."""
# we store parent into a one element list to avoid pytorch thinking it's a submodule # we store parent into a one element list to avoid pytorch thinking it's a submodule
_parent: "list[Chain]" _parent: "list[Chain]"
_can_refresh_parent: bool = True # see usage in Adapter and Chain _can_refresh_parent: bool = True # see usage in Adapter and Chain
@ -121,14 +156,21 @@ class ContextModule(Module):
@property @property
def parent(self) -> "Chain | None": def parent(self) -> "Chain | None":
"""Return the module's parent, or None if module is an orphan."""
return self._parent[0] if self._parent else None return self._parent[0] if self._parent else None
@property @property
def ensure_parent(self) -> "Chain": def ensure_parent(self) -> "Chain":
assert self._parent, "module is not bound to a Chain" """Return the module's parent, or raise an error if module is an orphan."""
return self._parent[0] assert self.parent, "module does not have a parent"
return self.parent
def get_parents(self) -> "list[Chain]":
"""Recursively retrieve the module's parents."""
return self._parent + self._parent[0].get_parents() if self._parent else []
def _set_parent(self, parent: "Chain | None") -> None: def _set_parent(self, parent: "Chain | None") -> None:
"""Set the parent of the module."""
if not self._can_refresh_parent: if not self._can_refresh_parent:
return return
if parent is None: if parent is None:
@ -140,11 +182,9 @@ class ContextModule(Module):
@property @property
def provider(self) -> ContextProvider: def provider(self) -> ContextProvider:
"""Return the module's context provider."""
return self.ensure_parent.provider return self.ensure_parent.provider
def get_parents(self) -> "list[Chain]":
return self._parent + self._parent[0].get_parents() if self._parent else []
def use_context(self, context_name: str) -> Context: def use_context(self, context_name: str) -> Context:
"""Retrieve the context object from the module's context provider.""" """Retrieve the context object from the module's context provider."""
context = self.provider.get_context(context_name) context = self.provider.get_context(context_name)
@ -170,20 +210,36 @@ class ContextModule(Module):
return clone return clone
def get_path(self, parent: "Chain | None" = None, top: "Module | None" = None) -> str: def get_path(self, parent: "Chain | None" = None, top: "Module | None" = None) -> str:
"""Get the path of the module in the chain.
Args:
parent: The parent of the module in the chain.
top: The top module of the chain.
If None, the path will be relative to the root of the chain.
"""
return super().get_path(parent=parent or self.parent, top=top) return super().get_path(parent=parent or self.parent, top=top)
class WeightedModule(Module): class WeightedModule(Module):
"""A module with a weight (Tensor) attribute."""
@property @property
def device(self) -> Device: def device(self) -> Device:
"""Return the device of the module's weight."""
return self.weight.device return self.weight.device
@property @property
def dtype(self) -> DType: def dtype(self) -> DType:
"""Return the dtype of the module's weight."""
return self.weight.dtype return self.weight.dtype
def __str__(self) -> str: def __str__(self) -> str:
return f"{super().__str__().removesuffix(')')}, device={self.device}, dtype={str(self.dtype).removeprefix('torch.')})" return (
f"{super().__str__().removesuffix(')')}, "
f"device={self.device}, "
f"dtype={str(self.dtype).removeprefix('torch.')})"
)
class TreeNode(TypedDict): class TreeNode(TypedDict):
@ -193,6 +249,11 @@ class TreeNode(TypedDict):
class ModuleTree: class ModuleTree:
"""A Tree of Modules.
This is useful to visualize the relations between modules.
"""
def __init__(self, module: Module) -> None: def __init__(self, module: Module) -> None:
self.root: TreeNode = self._module_to_tree(module=module) self.root: TreeNode = self._module_to_tree(module=module)
self._fold_successive_identical(node=self.root) self._fold_successive_identical(node=self.root)
@ -208,7 +269,13 @@ class ModuleTree:
yield child yield child
@classmethod @classmethod
def shorten_tree_repr(cls, tree_repr: str, /, line_index: int = 0, max_lines: int = 20) -> str: def shorten_tree_repr(
cls,
tree_repr: str,
/,
line_index: int = 0,
max_lines: int = 20,
) -> str:
"""Shorten the tree representation to a given number of lines around a given line index.""" """Shorten the tree representation to a given number of lines around a given line index."""
lines = tree_repr.split(sep="\n") lines = tree_repr.split(sep="\n")
start_idx = max(0, line_index - max_lines // 2) start_idx = max(0, line_index - max_lines // 2)
@ -216,7 +283,14 @@ class ModuleTree:
return "\n".join(lines[start_idx:end_idx]) return "\n".join(lines[start_idx:end_idx])
def _generate_tree_repr( def _generate_tree_repr(
self, node: TreeNode, /, *, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1 self,
node: TreeNode,
/,
*,
prefix: str = "",
is_last: bool = True,
is_root: bool = True,
depth: int = -1,
) -> str: ) -> str:
if depth == 0 and node["children"]: if depth == 0 and node["children"]:
return f"{prefix}{'└── ' if is_last else '├── '}{node['value']} ..." return f"{prefix}{'└── ' if is_last else '├── '}{node['value']} ..."