mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
(doc/fluxion/embedding) add/convert docstrings to mkdocstrings format
This commit is contained in:
parent
fc824bd53d
commit
beb6dfb1c4
|
@ -1,11 +1,34 @@
|
||||||
from jaxtyping import Float, Int
|
from torch import device as Device, dtype as DType
|
||||||
from torch import Tensor, device as Device, dtype as DType
|
|
||||||
from torch.nn import Embedding as _Embedding
|
from torch.nn import Embedding as _Embedding
|
||||||
|
|
||||||
from refiners.fluxion.layers.module import WeightedModule
|
from refiners.fluxion.layers.module import WeightedModule
|
||||||
|
|
||||||
|
|
||||||
class Embedding(_Embedding, WeightedModule): # type: ignore
|
class Embedding(_Embedding, WeightedModule):
|
||||||
|
"""Embedding layer.
|
||||||
|
|
||||||
|
This layer wraps [`torch.nn.Embedding`][torch.nn.Embedding].
|
||||||
|
|
||||||
|
Receives:
|
||||||
|
(Int[Tensor, "batch length"]):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(Float[Tensor, "batch length embedding_dim"]):
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
embedding = fl.Embedding(
|
||||||
|
num_embeddings=10,
|
||||||
|
embedding_dim=128
|
||||||
|
)
|
||||||
|
|
||||||
|
tensor = torch.randint(0, 10, (2, 10))
|
||||||
|
output = embedding(tensor)
|
||||||
|
|
||||||
|
assert output.shape == (2, 10, 128)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_embeddings: int,
|
num_embeddings: int,
|
||||||
|
@ -13,9 +36,18 @@ class Embedding(_Embedding, WeightedModule): # type: ignore
|
||||||
device: Device | str | None = None,
|
device: Device | str | None = None,
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
):
|
):
|
||||||
_Embedding.__init__( # type: ignore
|
"""Initializes the Embedding layer.
|
||||||
self, num_embeddings=num_embeddings, embedding_dim=embedding_dim, device=device, dtype=dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x: Int[Tensor, "batch length"]) -> Float[Tensor, "batch length embedding_dim"]: # type: ignore
|
Args:
|
||||||
return super().forward(x)
|
num_embeddings: The number of embeddings.
|
||||||
|
embedding_dim: The dimension of the embeddings.
|
||||||
|
device: The device to use for the embedding layer.
|
||||||
|
dtype: The dtype to use for the embedding layer.
|
||||||
|
"""
|
||||||
|
_Embedding.__init__( # type: ignore
|
||||||
|
self,
|
||||||
|
num_embeddings=num_embeddings,
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in a new issue