(doc/fluxion/adapter) add/convert docstrings to mkdocstrings format

This commit is contained in:
Laurent 2024-02-01 22:44:53 +00:00 committed by Laureηt
parent 4054421854
commit fe53cda5e2

View file

@ -8,6 +8,12 @@ TAdapter = TypeVar("TAdapter", bound="Adapter[Any]") # Self (see PEP 673)
class Adapter(Generic[T]):
"""Base class for adapters.
An Adapter modifies the structure of a [`Module`][refiners.fluxion.layers.Module]
(typically by adding, removing or replacing layers), to adapt it to a new task.
"""
# we store _target into a one element list to avoid pytorch thinking it is a submodule
_target: "list[T]"
@ -17,10 +23,20 @@ class Adapter(Generic[T]):
@property
def target(self) -> T:
"""The target of the adapter."""
return self._target[0]
@contextlib.contextmanager
def setup_adapter(self, target: T) -> Iterator[None]:
"""Setup the adapter.
This method should be called by the constructor of the adapter.
It sets the target of the adapter and ensures that the adapter
is not a submodule of the target.
Args:
target: The target of the adapter.
"""
assert isinstance(self, fl.Chain)
assert (not hasattr(self, "_modules")) or (
len(self) == 0
@ -37,6 +53,13 @@ class Adapter(Generic[T]):
target._can_refresh_parent = _old_can_refresh_parent
def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter:
"""Inject the adapter.
This method replaces the target of the adapter by the adapter inside the parent of the target.
Args:
parent: The parent to inject the adapter into, if the target doesn't have a parent.
"""
assert isinstance(self, fl.Chain)
if (parent is None) and isinstance(self.target, fl.ContextModule):
@ -62,6 +85,11 @@ class Adapter(Generic[T]):
return self
def eject(self) -> None:
"""Eject the adapter.
This method is the inverse of [`inject`][refiners.fluxion.adapters.Adapter.inject],
and should leave the target in the same state as before the injection.
"""
assert isinstance(self, fl.Chain)
# In general, the "actual target" is the target.