(docstrings) apply @deltheil suggestions

This commit is contained in:
Laurent 2024-02-02 09:50:15 +00:00 committed by Laureηt
parent 84d5796f08
commit 7307a3686e
6 changed files with 26 additions and 17 deletions

View file

@ -43,7 +43,8 @@ conversion = [
"tqdm>=4.62.3", "tqdm>=4.62.3",
] ]
doc = [ doc = [
"black>=24.1.1", # required by mkdocs to format the signatures # required by mkdocs to format the signatures
"black>=24.1.1",
"mkdocs-material>=9.5.6", "mkdocs-material>=9.5.6",
"mkdocstrings[python]>=0.24.0", "mkdocstrings[python]>=0.24.0",
"mkdocs-literate-nav>=0.6.1", "mkdocs-literate-nav>=0.6.1",

View file

@ -9,7 +9,7 @@ from refiners.fluxion.adapters.adapter import Adapter
class Lora(fl.Chain, ABC): class Lora(fl.Chain, ABC):
"""Low-rank approximation (LoRA) layer. """Low-Rank Adaptation (LoRA) layer.
This layer is composed of two [`WeightedModule`][refiners.fluxion.layers.WeightedModule]: This layer is composed of two [`WeightedModule`][refiners.fluxion.layers.WeightedModule]:
@ -156,7 +156,7 @@ class Lora(fl.Chain, ABC):
return LoraAdapter(layer, self), parent return LoraAdapter(layer, self), parent
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None: def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
"""Load the weights of the LoRA. """Load the (pre-trained) weights of the LoRA.
Args: Args:
down_weight: The down weight. down_weight: The down weight.
@ -169,7 +169,7 @@ class Lora(fl.Chain, ABC):
class LinearLora(Lora): class LinearLora(Lora):
"""Low-rank approximation (LoRA) layer for linear layers. """Low-Rank Adaptation (LoRA) layer for linear layers.
This layer uses two [`Linear`][refiners.fluxion.layers.Linear] layers as its down and up layers. This layer uses two [`Linear`][refiners.fluxion.layers.Linear] layers as its down and up layers.
""" """
@ -255,7 +255,7 @@ class LinearLora(Lora):
class Conv2dLora(Lora): class Conv2dLora(Lora):
"""Low-rank approximation (LoRA) layer for 2D convolutional layers. """Low-Rank Adaptation (LoRA) layer for 2D convolutional layers.
This layer uses two [`Conv2d`][refiners.fluxion.layers.Conv2d] layers as its down and up layers. This layer uses two [`Conv2d`][refiners.fluxion.layers.Conv2d] layers as its down and up layers.
""" """
@ -391,12 +391,12 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
@property @property
def loras(self) -> dict[str, Lora]: def loras(self) -> dict[str, Lora]:
"""The LoRA layers.""" """The LoRA layers indexed by name."""
return {lora.name: lora for lora in self.layers(Lora)} return {lora.name: lora for lora in self.layers(Lora)}
@property @property
def scales(self) -> dict[str, float]: def scales(self) -> dict[str, float]:
"""The scales of the LoRA layers.""" """The scales of the LoRA layers indexed by names."""
return {lora.name: lora.scale for lora in self.layers(Lora)} return {lora.name: lora.scale for lora in self.layers(Lora)}
@scales.setter @scales.setter
@ -407,6 +407,9 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
def add_lora(self, lora: Lora, /) -> None: def add_lora(self, lora: Lora, /) -> None:
"""Add a LoRA layer to the adapter. """Add a LoRA layer to the adapter.
Raises:
AssertionError: If the adapter already contains a LoRA layer with the same name.
Args: Args:
lora: The LoRA layer to add. lora: The LoRA layer to add.
""" """
@ -416,6 +419,9 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
def remove_lora(self, name: str, /) -> Lora | None: def remove_lora(self, name: str, /) -> Lora | None:
"""Remove a LoRA layer from the adapter. """Remove a LoRA layer from the adapter.
Note:
If the adapter doesn't contain a LoRA layer with the given name, nothing happens and `None` is returned.
Args: Args:
name: The name of the LoRA layer to remove. name: The name of the LoRA layer to remove.
""" """

View file

@ -23,7 +23,7 @@ class ContextProvider:
self.contexts[key] = value self.contexts[key] = value
def get_context(self, key: str) -> Any: def get_context(self, key: str) -> Any:
"""Retreive a value from the context. """Retrieve a value from the context.
Args: Args:
key: The key of the context. key: The key of the context.
@ -34,7 +34,7 @@ class ContextProvider:
return self.contexts.get(key) return self.contexts.get(key)
def update_contexts(self, new_contexts: Contexts) -> None: def update_contexts(self, new_contexts: Contexts) -> None:
"""Update the contexts with new contexts. """Update or set the contexts with new contexts.
Args: Args:
new_contexts: The new contexts. new_contexts: The new contexts.

View file

@ -55,7 +55,7 @@ class ReLU(Activation):
output = relu(tensor) output = relu(tensor)
expected_output = torch.tensor([[0.0, 0.0, 1.0]]) expected_output = torch.tensor([[0.0, 0.0, 1.0]])
assert torch.allclose(output, expected_output) assert torch.equal(output, expected_output)
``` ```
""" """

View file

@ -20,7 +20,9 @@ def scaled_dot_product_attention(
) -> Float[Tensor, "batch source_sequence_length dim"]: ) -> Float[Tensor, "batch source_sequence_length dim"]:
"""Scaled Dot Product Attention. """Scaled Dot Product Attention.
Optimization depends on which pytorch backend is used. Note:
Optimization depends on which PyTorch backend is used.
See [[arXiv:1706.03762] Attention Is All You Need (Equation 1)](https://arxiv.org/abs/1706.03762) for more details. See [[arXiv:1706.03762] Attention Is All You Need (Equation 1)](https://arxiv.org/abs/1706.03762) for more details.
See also [torch.nn.functional.scaled_dot_product_attention](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html). See also [torch.nn.functional.scaled_dot_product_attention](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
""" """
@ -213,7 +215,7 @@ class Attention(Chain):
which transforms the 3 inputs into Query, Key and Value which transforms the 3 inputs into Query, Key and Value
- a [`ScaledDotProductAttention`][refiners.fluxion.layers.attentions.ScaledDotProductAttention] layer - a [`ScaledDotProductAttention`][refiners.fluxion.layers.attentions.ScaledDotProductAttention] layer
- a [`Linear`][refiners.fluxion.layers.linear.Linear] layer, - a [`Linear`][refiners.fluxion.layers.linear.Linear] layer,
which further transforms the output of the which projects the output of the
[`ScaledDotProductAttention`][refiners.fluxion.layers.attentions.ScaledDotProductAttention] layer [`ScaledDotProductAttention`][refiners.fluxion.layers.attentions.ScaledDotProductAttention] layer
Receives: Receives:
@ -461,7 +463,7 @@ class SelfAttention2d(SelfAttention):
) -> Float[Tensor, "batch height*width channels"]: ) -> Float[Tensor, "batch height*width channels"]:
"""Transform a 2D Tensor into a sequence. """Transform a 2D Tensor into a sequence.
The height and width of the input Tensor are stored in the context, The height and width of the input Tensor are stored in a `"reshape"` context,
so that the output Tensor can be transformed back into a 2D Tensor in the `sequence_to_tensor_2d` method. so that the output Tensor can be transformed back into a 2D Tensor in the `sequence_to_tensor_2d` method.
""" """
height, width = x.shape[-2:] height, width = x.shape[-2:]
@ -480,7 +482,7 @@ class SelfAttention2d(SelfAttention):
) -> Float[Tensor, "batch channels height width"]: ) -> Float[Tensor, "batch channels height width"]:
"""Transform a sequence into a 2D Tensor. """Transform a sequence into a 2D Tensor.
The height and width of the output Tensor are retrieved from the context, The height and width of the output Tensor are retrieved from the `"reshape"` context,
which was set in the `tensor_2d_to_sequence` method. which was set in the `tensor_2d_to_sequence` method.
""" """
height, width = self.use_context("reshape").values() height, width = self.use_context("reshape").values()

View file

@ -17,7 +17,7 @@ class Identity(Module):
tensor = torch.randn(10, 10) tensor = torch.randn(10, 10)
output = identity(tensor) output = identity(tensor)
assert torch.allclose(tensor, output) assert torch.equal(tensor, output)
``` ```
""" """
@ -51,9 +51,9 @@ class GetArg(Module):
torch.randn(20, 20), torch.randn(20, 20),
torch.randn(30, 30), torch.randn(30, 30),
) )
output = get_arg(inputs) output = get_arg(*inputs)
assert torch.allclose(tensor[1], output) assert id(inputs[1]) == id(output)
``` ```
""" """