Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-22 06:08:46 +00:00)
split PositionalTokenEncoder
This commit is contained in:
parent 9d663534d1
commit 6fd5894caf
@@ -3,40 +3,51 @@ import refiners.fluxion.layers as fl
 from refiners.foundationals.clip.tokenizer import CLIPTokenizer
 
 
-class PositionalTokenEncoder(fl.Sum):
-    structural_attrs = ["vocabulary_size", "positional_embedding_dim"]
+class TokenEncoder(fl.Embedding):
+    structural_attrs = ["vocabulary_size", "embedding_dim"]
 
     def __init__(
         self,
         vocabulary_size: int,
         embedding_dim: int,
-        positional_embedding_dim: int,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
         self.vocabulary_size = vocabulary_size
-        self.positional_embedding_dim = positional_embedding_dim
+        self.embedding_dim = embedding_dim
         super().__init__(
+            num_embeddings=vocabulary_size,
+            embedding_dim=embedding_dim,
+            device=device,
+            dtype=dtype,
+        )
+
+
+class PositionalEncoder(fl.Chain):
+    structural_attrs = ["max_sequence_length", "embedding_dim"]
+
+    def __init__(
+        self,
+        max_sequence_length: int,
+        embedding_dim: int,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        self.max_sequence_length = max_sequence_length
+        self.embedding_dim = embedding_dim
+        super().__init__(
+            fl.Lambda(func=self.get_position_ids),
             fl.Embedding(
-                num_embeddings=vocabulary_size,
+                num_embeddings=max_sequence_length,
                 embedding_dim=embedding_dim,
                 device=device,
                 dtype=dtype,
             ),
-            fl.Chain(
-                fl.Lambda(func=self.get_position_ids),
-                fl.Embedding(
-                    num_embeddings=positional_embedding_dim,
-                    embedding_dim=embedding_dim,
-                    device=device,
-                    dtype=dtype,
-                ),
-            ),
         )
 
     @property
     def position_ids(self) -> Tensor:
-        return arange(end=self.positional_embedding_dim, device=self.device).reshape(1, -1)
+        return arange(end=self.max_sequence_length, device=self.device).reshape(1, -1)
 
     def get_position_ids(self, x: Tensor) -> Tensor:
         return self.position_ids[:, : x.shape[1]]
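A minimal sketch of what the two modules introduced above do on their own. The constructor signatures come from the diff; the import path and the CLIP-style sizes (49408 tokens, 77 positions, 768 channels) are illustrative assumptions, not part of this commit.

# Sketch only (not part of the commit). Import path and sizes are assumed.
# TokenEncoder looks up one vector per token id; PositionalEncoder ignores
# the token values and only uses the sequence length of its input.
import torch

from refiners.foundationals.clip.text_encoder import PositionalEncoder, TokenEncoder

token_encoder = TokenEncoder(vocabulary_size=49408, embedding_dim=768)
positional_encoder = PositionalEncoder(max_sequence_length=77, embedding_dim=768)

tokens = torch.randint(low=0, high=49408, size=(1, 77))  # dummy token ids

token_embeddings = token_encoder(tokens)          # one embedding row per token id
position_embeddings = positional_encoder(tokens)  # depends only on position index

assert token_embeddings.shape == (1, 77, 768)
assert position_embeddings.shape == (1, 77, 768)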
@@ -147,12 +158,19 @@ class CLIPTextEncoder(fl.Chain):
         self.use_quick_gelu = use_quick_gelu
         self.tokenizer = tokenizer or CLIPTokenizer()
         super().__init__(
-            PositionalTokenEncoder(
-                vocabulary_size=vocabulary_size,
-                embedding_dim=embedding_dim,
-                positional_embedding_dim=positional_embedding_dim,
-                device=device,
-                dtype=dtype,
+            fl.Sum(
+                TokenEncoder(
+                    vocabulary_size=vocabulary_size,
+                    embedding_dim=embedding_dim,
+                    device=device,
+                    dtype=dtype,
+                ),
+                PositionalEncoder(
+                    max_sequence_length=positional_embedding_dim,
+                    embedding_dim=embedding_dim,
+                    device=device,
+                    dtype=dtype,
+                ),
             ),
             *(
                 TransformerLayer(
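For reference, a hedged sketch of the embedding stage as CLIPTextEncoder now wires it: fl.Sum runs both encoders on the same token ids and adds their outputs, which is what the removed PositionalTokenEncoder (itself an fl.Sum) computed. The import path and sizes are again illustrative assumptions.

# Sketch only: the fl.Sum wiring from the second hunk, with assumed import
# path and illustrative sizes. fl.Sum feeds the same input to each child and
# adds the results, matching the old PositionalTokenEncoder behaviour.
import torch

import refiners.fluxion.layers as fl
from refiners.foundationals.clip.text_encoder import PositionalEncoder, TokenEncoder

token_encoder = TokenEncoder(vocabulary_size=49408, embedding_dim=768)
positional_encoder = PositionalEncoder(max_sequence_length=77, embedding_dim=768)
embedding_stage = fl.Sum(token_encoder, positional_encoder)

tokens = torch.randint(low=0, high=49408, size=(1, 77))
assert torch.allclose(embedding_stage(tokens), token_encoder(tokens) + positional_encoder(tokens))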