From beb6dfb1c4a79d776bd45af6cc5268db94e21f1b Mon Sep 17 00:00:00 2001
From: Laurent
Date: Thu, 1 Feb 2024 21:51:43 +0000
Subject: [PATCH] (doc/fluxion/embedding) add/convert docstrings to mkdocstrings format

---
 src/refiners/fluxion/layers/embedding.py | 48 ++++++++++++++++++++----
 1 file changed, 40 insertions(+), 8 deletions(-)

diff --git a/src/refiners/fluxion/layers/embedding.py b/src/refiners/fluxion/layers/embedding.py
index 81eb6b1..e5399f8 100644
--- a/src/refiners/fluxion/layers/embedding.py
+++ b/src/refiners/fluxion/layers/embedding.py
@@ -1,11 +1,34 @@
-from jaxtyping import Float, Int
-from torch import Tensor, device as Device, dtype as DType
+from torch import device as Device, dtype as DType
 from torch.nn import Embedding as _Embedding
 
 from refiners.fluxion.layers.module import WeightedModule
 
 
-class Embedding(_Embedding, WeightedModule):  # type: ignore
+class Embedding(_Embedding, WeightedModule):
+    """Embedding layer.
+
+    This layer wraps [`torch.nn.Embedding`][torch.nn.Embedding].
+
+    Receives:
+        (Int[Tensor, "batch length"]):
+
+    Returns:
+        (Float[Tensor, "batch length embedding_dim"]):
+
+    Example:
+        ```py
+        embedding = fl.Embedding(
+            num_embeddings=10,
+            embedding_dim=128
+        )
+
+        tensor = torch.randint(0, 10, (2, 10))
+        output = embedding(tensor)
+
+        assert output.shape == (2, 10, 128)
+        ```
+    """
+
     def __init__(
         self,
         num_embeddings: int,
@@ -13,9 +36,18 @@ class Embedding(_Embedding, WeightedModule):  # type: ignore
         device: Device | str | None = None,
         dtype: DType | None = None,
     ):
-        _Embedding.__init__(  # type: ignore
-            self, num_embeddings=num_embeddings, embedding_dim=embedding_dim, device=device, dtype=dtype
-        )
+        """Initializes the Embedding layer.
 
-    def forward(self, x: Int[Tensor, "batch length"]) -> Float[Tensor, "batch length embedding_dim"]:  # type: ignore
-        return super().forward(x)
+        Args:
+            num_embeddings: The number of embeddings.
+            embedding_dim: The dimension of the embeddings.
+            device: The device to use for the embedding layer.
+            dtype: The dtype to use for the embedding layer.
+        """
+        _Embedding.__init__(  # type: ignore
+            self,
+            num_embeddings=num_embeddings,
+            embedding_dim=embedding_dim,
+            device=device,
+            dtype=dtype,
+        )
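
Below the patch, a minimal usage sketch of the layer as documented above. It assumes `refiners` is importable and, as in the docstring example, that `refiners.fluxion.layers` is imported as `fl`; the explicit `device` and `dtype` values are illustrative only.

```py
import torch

import refiners.fluxion.layers as fl

# Construct the layer with the documented Args, including the
# device/dtype keywords forwarded to torch.nn.Embedding.
embedding = fl.Embedding(
    num_embeddings=10,
    embedding_dim=128,
    device="cpu",
    dtype=torch.float32,
)

# Receives: Int[Tensor, "batch length"]
tokens = torch.randint(0, 10, (2, 10))

# Returns: Float[Tensor, "batch length embedding_dim"]
embedded = embedding(tokens)
assert embedded.shape == (2, 10, 128)
assert embedded.dtype == torch.float32
```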