(doc/fluxion/embedding) add/convert docstrings to mkdocstrings format

This commit is contained in:
Laurent 2024-02-01 21:51:43 +00:00 committed by Laureηt
parent fc824bd53d
commit beb6dfb1c4

View file

@ -1,11 +1,34 @@
from jaxtyping import Float, Int
from torch import Tensor, device as Device, dtype as DType
from torch import device as Device, dtype as DType
from torch.nn import Embedding as _Embedding
from refiners.fluxion.layers.module import WeightedModule
class Embedding(_Embedding, WeightedModule): # type: ignore
class Embedding(_Embedding, WeightedModule):
"""Embedding layer.
This layer wraps [`torch.nn.Embedding`][torch.nn.Embedding].
Receives:
(Int[Tensor, "batch length"]):
Returns:
(Float[Tensor, "batch length embedding_dim"]):
Example:
```py
embedding = fl.Embedding(
num_embeddings=10,
embedding_dim=128
)
tensor = torch.randint(0, 10, (2, 10))
output = embedding(tensor)
assert output.shape == (2, 10, 128)
```
"""
def __init__(
self,
num_embeddings: int,
@ -13,9 +36,18 @@ class Embedding(_Embedding, WeightedModule): # type: ignore
device: Device | str | None = None,
dtype: DType | None = None,
):
_Embedding.__init__( # type: ignore
self, num_embeddings=num_embeddings, embedding_dim=embedding_dim, device=device, dtype=dtype
)
"""Initializes the Embedding layer.
def forward(self, x: Int[Tensor, "batch length"]) -> Float[Tensor, "batch length embedding_dim"]: # type: ignore
return super().forward(x)
Args:
num_embeddings: The number of embeddings.
embedding_dim: The dimension of the embeddings.
device: The device to use for the embedding layer.
dtype: The dtype to use for the embedding layer.
"""
_Embedding.__init__( # type: ignore
self,
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
)