(docstrings) apply @deltheil suggestions

This commit is contained in:
parent 84d5796f08
commit 7307a3686e
@@ -43,7 +43,8 @@ conversion = [
     "tqdm>=4.62.3",
 ]
 doc = [
-    "black>=24.1.1", # required by mkdocs to format the signatures
+    # required by mkdocs to format the signatures
+    "black>=24.1.1",
     "mkdocs-material>=9.5.6",
     "mkdocstrings[python]>=0.24.0",
     "mkdocs-literate-nav>=0.6.1",
@@ -9,7 +9,7 @@ from refiners.fluxion.adapters.adapter import Adapter


 class Lora(fl.Chain, ABC):
-    """Low-rank approximation (LoRA) layer.
+    """Low-Rank Adaptation (LoRA) layer.

     This layer is composed of two [`WeightedModule`][refiners.fluxion.layers.WeightedModule]:

@@ -156,7 +156,7 @@ class Lora(fl.Chain, ABC):
         return LoraAdapter(layer, self), parent

     def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
-        """Load the weights of the LoRA.
+        """Load the (pre-trained) weights of the LoRA.

         Args:
             down_weight: The down weight.
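The docstrings above describe a LoRA layer as a pair of WeightedModules (a "down" projection to a small rank and an "up" projection back) plus a scale, with `load_weights` copying pre-trained matrices into them. Below is a rough, self-contained sketch of that idea in plain PyTorch; the dimensions, rank, and scale are made up for illustration, and this is not the refiners implementation.

```python
import torch

# Rough LoRA sketch (not the refiners implementation): a base linear layer
# plus a scaled low-rank residual built from two small matrices
# ("down" projects to the rank, "up" projects back to the output size).
in_features, out_features, rank, scale = 64, 128, 4, 1.0

base = torch.nn.Linear(in_features, out_features, bias=False)
down = torch.nn.Linear(in_features, rank, bias=False)  # the "down" module
up = torch.nn.Linear(rank, out_features, bias=False)   # the "up" module
torch.nn.init.zeros_(up.weight)  # common choice: the residual starts as a no-op


def lora_forward(x: torch.Tensor) -> torch.Tensor:
    # base output + scaled low-rank correction
    return base(x) + scale * up(down(x))


# Loading pre-trained LoRA weights amounts to copying the two matrices.
down_weight = torch.randn(rank, in_features)
up_weight = torch.randn(out_features, rank)
with torch.no_grad():
    down.weight.copy_(down_weight)
    up.weight.copy_(up_weight)

x = torch.randn(2, in_features)
assert lora_forward(x).shape == (2, out_features)
```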
@@ -169,7 +169,7 @@ class Lora(fl.Chain, ABC):


 class LinearLora(Lora):
-    """Low-rank approximation (LoRA) layer for linear layers.
+    """Low-Rank Adaptation (LoRA) layer for linear layers.

     This layer uses two [`Linear`][refiners.fluxion.layers.Linear] layers as its down and up layers.
     """
@@ -255,7 +255,7 @@ class LinearLora(Lora):


 class Conv2dLora(Lora):
-    """Low-rank approximation (LoRA) layer for 2D convolutional layers.
+    """Low-Rank Adaptation (LoRA) layer for 2D convolutional layers.

     This layer uses two [`Conv2d`][refiners.fluxion.layers.Conv2d] layers as its down and up layers.
     """
@@ -391,12 +391,12 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):

     @property
     def loras(self) -> dict[str, Lora]:
-        """The LoRA layers."""
+        """The LoRA layers indexed by name."""
         return {lora.name: lora for lora in self.layers(Lora)}

     @property
     def scales(self) -> dict[str, float]:
-        """The scales of the LoRA layers."""
+        """The scales of the LoRA layers indexed by names."""
         return {lora.name: lora.scale for lora in self.layers(Lora)}

     @scales.setter
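Both properties return dictionaries keyed by the LoRA name, and the `scales` setter (whose decorator is visible at the end of the hunk) accepts a mapping of the same shape. A toy stand-in, not the real `LoraAdapter`, illustrating that name-indexed pattern:

```python
from dataclasses import dataclass


@dataclass
class ToyLora:
    name: str
    scale: float


class ToyAdapter:
    """Toy stand-in (not refiners) mirroring the name-indexed properties above."""

    def __init__(self, loras: list[ToyLora]) -> None:
        self._loras = loras

    @property
    def loras(self) -> dict[str, ToyLora]:
        return {lora.name: lora for lora in self._loras}

    @property
    def scales(self) -> dict[str, float]:
        return {lora.name: lora.scale for lora in self._loras}

    @scales.setter
    def scales(self, scales: dict[str, float]) -> None:
        for lora in self._loras:
            lora.scale = scales[lora.name]


adapter = ToyAdapter([ToyLora("style", 1.0), ToyLora("pose", 1.0)])
adapter.scales = {"style": 0.5, "pose": 1.0}
assert adapter.scales == {"style": 0.5, "pose": 1.0}
```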
@@ -407,6 +407,9 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
     def add_lora(self, lora: Lora, /) -> None:
         """Add a LoRA layer to the adapter.

+        Raises:
+            AssertionError: If the adapter already contains a LoRA layer with the same name.
+
         Args:
             lora: The LoRA layer to add.
         """
@@ -416,6 +419,9 @@ class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
     def remove_lora(self, name: str, /) -> Lora | None:
         """Remove a LoRA layer from the adapter.

+        Note:
+            If the adapter doesn't contain a LoRA layer with the given name, nothing happens and `None` is returned.
+
         Args:
             name: The name of the LoRA layer to remove.
         """
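The two docstring additions above pin down the edge cases: adding a LoRA under a name that is already present is an assertion error, while removing an unknown name is a silent no-op that returns `None`. A minimal toy registry, not the refiners class, showing those semantics:

```python
class ToyLoraRegistry:
    """Toy stand-in (not refiners) for the add/remove semantics documented above."""

    def __init__(self) -> None:
        self._loras: dict[str, object] = {}

    def add_lora(self, name: str, lora: object) -> None:
        # Adding a duplicate name is an error (AssertionError), as documented.
        assert name not in self._loras, f"LoRA layer {name} already exists"
        self._loras[name] = lora

    def remove_lora(self, name: str) -> object | None:
        # Removing an unknown name does nothing and returns None, as documented.
        return self._loras.pop(name, None)


registry = ToyLoraRegistry()
registry.add_lora("style", object())
assert registry.remove_lora("unknown") is None
assert registry.remove_lora("style") is not None
```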
@@ -23,7 +23,7 @@ class ContextProvider:
         self.contexts[key] = value

     def get_context(self, key: str) -> Any:
-        """Retreive a value from the context.
+        """Retrieve a value from the context.

         Args:
             key: The key of the context.
@@ -34,7 +34,7 @@ class ContextProvider:
         return self.contexts.get(key)

     def update_contexts(self, new_contexts: Contexts) -> None:
-        """Update the contexts with new contexts.
+        """Update or set the contexts with new contexts.

         Args:
             new_contexts: The new contexts.
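Taken together, `get_context` and `update_contexts` describe a small keyed store: missing keys read back as `None`, and an update either merges into an existing context or creates it. A dict-backed toy, not the real `ContextProvider`, sketching that behavior:

```python
from typing import Any

Contexts = dict[str, dict[str, Any]]


class ToyContextProvider:
    """Toy stand-in (not refiners) for the get/update behavior documented above."""

    def __init__(self) -> None:
        self.contexts: Contexts = {}

    def set_context(self, key: str, value: dict[str, Any]) -> None:
        self.contexts[key] = value

    def get_context(self, key: str) -> Any:
        # Missing keys yield None rather than raising.
        return self.contexts.get(key)

    def update_contexts(self, new_contexts: Contexts) -> None:
        # Update existing entries, set the ones that don't exist yet.
        for key, value in new_contexts.items():
            self.contexts.setdefault(key, {}).update(value)


provider = ToyContextProvider()
provider.set_context("reshape", {"height": 32, "width": 32})
provider.update_contexts({"reshape": {"width": 64}, "diffusion": {"step": 0}})
assert provider.get_context("reshape") == {"height": 32, "width": 64}
assert provider.get_context("missing") is None
```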
@@ -55,7 +55,7 @@ class ReLU(Activation):
         output = relu(tensor)

         expected_output = torch.tensor([[0.0, 0.0, 1.0]])
-        assert torch.allclose(output, expected_output)
+        assert torch.equal(output, expected_output)
         ```
     """

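The switch from `torch.allclose` to `torch.equal` here is deliberate: `equal` requires identical shapes and exactly equal values, whereas `allclose` tolerates small numerical differences. A quick illustration:

```python
import torch

a = torch.tensor([1.0, 2.0])
b = a + 1e-6  # numerically close to `a`, but not bit-identical

assert torch.allclose(a, b)       # passes: within default tolerances
assert not torch.equal(a, b)      # exact comparison: the values differ
assert torch.equal(a, a.clone())  # exact copies compare equal
```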
@@ -20,7 +20,9 @@ def scaled_dot_product_attention(
 ) -> Float[Tensor, "batch source_sequence_length dim"]:
     """Scaled Dot Product Attention.

-    Optimization depends on which pytorch backend is used.
+    Note:
+        Optimization depends on which PyTorch backend is used.
+
     See [[arXiv:1706.03762] Attention Is All You Need (Equation 1)](https://arxiv.org/abs/1706.03762) for more details.
     See also [torch.nn.functional.scaled_dot_product_attention](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
     """
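Equation 1 of the cited paper, softmax(QK^T / sqrt(d_k))V, can be written out as a naive PyTorch reference. Unlike the built-in kernel it does no backend-specific optimization, but the results should agree within floating-point tolerance:

```python
import math

import torch


def naive_scaled_dot_product_attention(
    query: torch.Tensor,  # (batch, seq_q, dim)
    key: torch.Tensor,    # (batch, seq_kv, dim)
    value: torch.Tensor,  # (batch, seq_kv, dim)
) -> torch.Tensor:
    # Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.shape[-1])
    return torch.softmax(scores, dim=-1) @ value


q = torch.randn(2, 5, 16)
k = torch.randn(2, 7, 16)
v = torch.randn(2, 7, 16)
out = naive_scaled_dot_product_attention(q, k, v)

# The optimized PyTorch kernel should agree up to floating-point tolerance.
expected = torch.nn.functional.scaled_dot_product_attention(q, k, v)
assert torch.allclose(out, expected, atol=1e-5)
```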
@@ -213,7 +215,7 @@ class Attention(Chain):
         which transforms the 3 inputs into Query, Key and Value
     - a [`ScaledDotProductAttention`][refiners.fluxion.layers.attentions.ScaledDotProductAttention] layer
     - a [`Linear`][refiners.fluxion.layers.linear.Linear] layer,
-        which further transforms the output of the
+        which projects the output of the
         [`ScaledDotProductAttention`][refiners.fluxion.layers.attentions.ScaledDotProductAttention] layer

     Receives:
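The structure being documented is the usual projection sandwich: three `Linear` layers produce Query, Key and Value, a scaled dot product attention combines them, and a final `Linear` projects the result. A plain-PyTorch sketch of that shape (single-head, hypothetical dimensions; not the refiners `Chain`):

```python
import torch


class TinyAttention(torch.nn.Module):
    """Plain-PyTorch sketch of the Linear -> attention -> Linear structure."""

    def __init__(self, embedding_dim: int) -> None:
        super().__init__()
        # One Linear per input, producing Query, Key and Value...
        self.to_q = torch.nn.Linear(embedding_dim, embedding_dim)
        self.to_k = torch.nn.Linear(embedding_dim, embedding_dim)
        self.to_v = torch.nn.Linear(embedding_dim, embedding_dim)
        # ...and a final Linear projecting the attention output.
        self.to_out = torch.nn.Linear(embedding_dim, embedding_dim)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        q, k, v = self.to_q(query), self.to_k(key), self.to_v(value)
        attended = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        return self.to_out(attended)


attention = TinyAttention(embedding_dim=32)
x = torch.randn(2, 10, 32)
assert attention(x, x, x).shape == x.shape  # self-attention: same input three times
```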
@@ -461,7 +463,7 @@ class SelfAttention2d(SelfAttention):
     ) -> Float[Tensor, "batch height*width channels"]:
         """Transform a 2D Tensor into a sequence.

-        The height and width of the input Tensor are stored in the context,
+        The height and width of the input Tensor are stored in a `"reshape"` context,
         so that the output Tensor can be transformed back into a 2D Tensor in the `sequence_to_tensor_2d` method.
         """
         height, width = x.shape[-2:]
@@ -480,7 +482,7 @@ class SelfAttention2d(SelfAttention):
     ) -> Float[Tensor, "batch channels height width"]:
         """Transform a sequence into a 2D Tensor.

-        The height and width of the output Tensor are retrieved from the context,
+        The height and width of the output Tensor are retrieved from the `"reshape"` context,
         which was set in the `tensor_2d_to_sequence` method.
         """
         height, width = self.use_context("reshape").values()
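The two methods form a round trip: flatten the spatial grid into a sequence so attention can run over it, then fold it back using the height and width remembered in the `"reshape"` context. A standalone sketch of that round trip, with a plain dict standing in for the context (not the refiners methods themselves):

```python
import torch

reshape_context: dict[str, int] = {}  # stand-in for the "reshape" context


def tensor_2d_to_sequence(x: torch.Tensor) -> torch.Tensor:
    # (batch, channels, height, width) -> (batch, height*width, channels)
    batch, channels, height, width = x.shape
    reshape_context["height"], reshape_context["width"] = height, width
    return x.flatten(start_dim=2).transpose(1, 2)


def sequence_to_tensor_2d(x: torch.Tensor) -> torch.Tensor:
    # (batch, height*width, channels) -> (batch, channels, height, width)
    height, width = reshape_context["height"], reshape_context["width"]
    return x.transpose(1, 2).unflatten(dim=2, sizes=(height, width))


image = torch.randn(1, 8, 4, 6)
assert torch.equal(sequence_to_tensor_2d(tensor_2d_to_sequence(image)), image)
```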
@@ -17,7 +17,7 @@ class Identity(Module):
         tensor = torch.randn(10, 10)
         output = identity(tensor)

-        assert torch.allclose(tensor, output)
+        assert torch.equal(tensor, output)
         ```
     """

@@ -51,9 +51,9 @@ class GetArg(Module):
             torch.randn(20, 20),
             torch.randn(30, 30),
         )
-        output = get_arg(inputs)
+        output = get_arg(*inputs)

-        assert torch.allclose(tensor[1], output)
+        assert id(inputs[1]) == id(output)
         ```
     """

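The corrected example unpacks the tuple (`get_arg(*inputs)`) and checks identity rather than closeness, since a `GetArg`-style layer just returns one of its positional arguments untouched. A toy version, assuming the layer is constructed with the index to return:

```python
import torch


class ToyGetArg(torch.nn.Module):
    """Toy stand-in (not refiners) returning the n-th positional argument as-is."""

    def __init__(self, index: int) -> None:
        super().__init__()
        self.index = index

    def forward(self, *args: torch.Tensor) -> torch.Tensor:
        return args[self.index]


get_arg = ToyGetArg(index=1)
inputs = (torch.randn(10, 10), torch.randn(20, 20), torch.randn(30, 30))
output = get_arg(*inputs)

# The same object comes back, so identity (not just closeness) holds.
assert output is inputs[1]
```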