mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
(doc/fluxion/adapter) add/convert docstrings to mkdocstrings format
This commit is contained in:
parent
4054421854
commit
fe53cda5e2
|
@ -8,6 +8,12 @@ TAdapter = TypeVar("TAdapter", bound="Adapter[Any]") # Self (see PEP 673)
|
||||||
|
|
||||||
|
|
||||||
class Adapter(Generic[T]):
|
class Adapter(Generic[T]):
|
||||||
|
"""Base class for adapters.
|
||||||
|
|
||||||
|
An Adapter modifies the structure of a [`Module`][refiners.fluxion.layers.Module]
|
||||||
|
(typically by adding, removing or replacing layers), to adapt it to a new task.
|
||||||
|
"""
|
||||||
|
|
||||||
# we store _target into a one element list to avoid pytorch thinking it is a submodule
|
# we store _target into a one element list to avoid pytorch thinking it is a submodule
|
||||||
_target: "list[T]"
|
_target: "list[T]"
|
||||||
|
|
||||||
|
@ -17,10 +23,20 @@ class Adapter(Generic[T]):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def target(self) -> T:
|
def target(self) -> T:
|
||||||
|
"""The target of the adapter."""
|
||||||
return self._target[0]
|
return self._target[0]
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def setup_adapter(self, target: T) -> Iterator[None]:
|
def setup_adapter(self, target: T) -> Iterator[None]:
|
||||||
|
"""Setup the adapter.
|
||||||
|
|
||||||
|
This method should be called by the constructor of the adapter.
|
||||||
|
It sets the target of the adapter and ensures that the adapter
|
||||||
|
is not a submodule of the target.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target: The target of the adapter.
|
||||||
|
"""
|
||||||
assert isinstance(self, fl.Chain)
|
assert isinstance(self, fl.Chain)
|
||||||
assert (not hasattr(self, "_modules")) or (
|
assert (not hasattr(self, "_modules")) or (
|
||||||
len(self) == 0
|
len(self) == 0
|
||||||
|
@ -37,6 +53,13 @@ class Adapter(Generic[T]):
|
||||||
target._can_refresh_parent = _old_can_refresh_parent
|
target._can_refresh_parent = _old_can_refresh_parent
|
||||||
|
|
||||||
def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter:
|
def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter:
|
||||||
|
"""Inject the adapter.
|
||||||
|
|
||||||
|
This method replaces the target of the adapter by the adapter inside the parent of the target.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parent: The parent to inject the adapter into, if the target doesn't have a parent.
|
||||||
|
"""
|
||||||
assert isinstance(self, fl.Chain)
|
assert isinstance(self, fl.Chain)
|
||||||
|
|
||||||
if (parent is None) and isinstance(self.target, fl.ContextModule):
|
if (parent is None) and isinstance(self.target, fl.ContextModule):
|
||||||
|
@ -62,6 +85,11 @@ class Adapter(Generic[T]):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def eject(self) -> None:
|
def eject(self) -> None:
|
||||||
|
"""Eject the adapter.
|
||||||
|
|
||||||
|
This method is the inverse of [`inject`][refiners.fluxion.adapters.Adapter.inject],
|
||||||
|
and should leave the target in the same state as before the injection.
|
||||||
|
"""
|
||||||
assert isinstance(self, fl.Chain)
|
assert isinstance(self, fl.Chain)
|
||||||
|
|
||||||
# In general, the "actual target" is the target.
|
# In general, the "actual target" is the target.
|
||||||
|
|
Loading…
Reference in a new issue