mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
(doc/fluxion/embedding) add/convert docstrings to mkdocstrings format
This commit is contained in:
parent
fc824bd53d
commit
beb6dfb1c4
|
@ -1,11 +1,34 @@
|
|||
from jaxtyping import Float, Int
|
||||
from torch import Tensor, device as Device, dtype as DType
|
||||
from torch import device as Device, dtype as DType
|
||||
from torch.nn import Embedding as _Embedding
|
||||
|
||||
from refiners.fluxion.layers.module import WeightedModule
|
||||
|
||||
|
||||
class Embedding(_Embedding, WeightedModule): # type: ignore
|
||||
class Embedding(_Embedding, WeightedModule):
|
||||
"""Embedding layer.
|
||||
|
||||
This layer wraps [`torch.nn.Embedding`][torch.nn.Embedding].
|
||||
|
||||
Receives:
|
||||
(Int[Tensor, "batch length"]):
|
||||
|
||||
Returns:
|
||||
(Float[Tensor, "batch length embedding_dim"]):
|
||||
|
||||
Example:
|
||||
```py
|
||||
embedding = fl.Embedding(
|
||||
num_embeddings=10,
|
||||
embedding_dim=128
|
||||
)
|
||||
|
||||
tensor = torch.randint(0, 10, (2, 10))
|
||||
output = embedding(tensor)
|
||||
|
||||
assert output.shape == (2, 10, 128)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
|
@ -13,9 +36,18 @@ class Embedding(_Embedding, WeightedModule): # type: ignore
|
|||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
):
|
||||
_Embedding.__init__( # type: ignore
|
||||
self, num_embeddings=num_embeddings, embedding_dim=embedding_dim, device=device, dtype=dtype
|
||||
)
|
||||
"""Initializes the Embedding layer.
|
||||
|
||||
def forward(self, x: Int[Tensor, "batch length"]) -> Float[Tensor, "batch length embedding_dim"]: # type: ignore
|
||||
return super().forward(x)
|
||||
Args:
|
||||
num_embeddings: The number of embeddings.
|
||||
embedding_dim: The dimension of the embeddings.
|
||||
device: The device to use for the embedding layer.
|
||||
dtype: The dtype to use for the embedding layer.
|
||||
"""
|
||||
_Embedding.__init__( # type: ignore
|
||||
self,
|
||||
num_embeddings=num_embeddings,
|
||||
embedding_dim=embedding_dim,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue