split PositionalTokenEncoder

limiteinductive 2023-08-16 15:20:21 +02:00 committed by Benjamin Trom
parent 9d663534d1
commit 6fd5894caf


@@ -3,40 +3,51 @@ import refiners.fluxion.layers as fl
 from refiners.foundationals.clip.tokenizer import CLIPTokenizer


-class PositionalTokenEncoder(fl.Sum):
-    structural_attrs = ["vocabulary_size", "positional_embedding_dim"]
+class TokenEncoder(fl.Embedding):
+    structural_attrs = ["vocabulary_size", "embedding_dim"]

     def __init__(
         self,
         vocabulary_size: int,
         embedding_dim: int,
-        positional_embedding_dim: int,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
         self.vocabulary_size = vocabulary_size
-        self.positional_embedding_dim = positional_embedding_dim
+        self.embedding_dim = embedding_dim
         super().__init__(
-            fl.Embedding(
             num_embeddings=vocabulary_size,
             embedding_dim=embedding_dim,
             device=device,
             dtype=dtype,
-            ),
-            fl.Chain(
+        )
+
+
+class PositionalEncoder(fl.Chain):
+    structural_attrs = ["max_sequence_length", "embedding_dim"]
+
+    def __init__(
+        self,
+        max_sequence_length: int,
+        embedding_dim: int,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        self.max_sequence_length = max_sequence_length
+        self.embedding_dim = embedding_dim
+        super().__init__(
             fl.Lambda(func=self.get_position_ids),
             fl.Embedding(
-                num_embeddings=positional_embedding_dim,
+                num_embeddings=max_sequence_length,
                 embedding_dim=embedding_dim,
                 device=device,
                 dtype=dtype,
             ),
-            ),
         )

     @property
     def position_ids(self) -> Tensor:
-        return arange(end=self.positional_embedding_dim, device=self.device).reshape(1, -1)
+        return arange(end=self.max_sequence_length, device=self.device).reshape(1, -1)

     def get_position_ids(self, x: Tensor) -> Tensor:
         return self.position_ids[:, : x.shape[1]]
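The split keeps the behaviour of the old PositionalTokenEncoder: CLIPTextEncoder now wraps TokenEncoder and PositionalEncoder in an fl.Sum (second hunk below), so token embeddings and positional embeddings are still added together. A minimal usage sketch follows; the import path, the CLIP-style hyperparameters (vocabulary size 49408, sequence length 77, embedding dimension 768), and the assumption that fl.Sum feeds the same input to both children and sums their outputs are illustrative, not part of this diff.

    # Sketch only: module path and hyperparameters are assumptions.
    import torch
    import refiners.fluxion.layers as fl
    from refiners.foundationals.clip.text_encoder import TokenEncoder, PositionalEncoder  # assumed path

    embedder = fl.Sum(
        TokenEncoder(vocabulary_size=49408, embedding_dim=768),        # token ids -> token embeddings
        PositionalEncoder(max_sequence_length=77, embedding_dim=768),  # token ids -> positional embeddings
    )

    token_ids = torch.randint(low=0, high=49408, size=(1, 77))  # (batch, sequence) of token ids
    embeddings = embedder(token_ids)                             # expected shape: (1, 77, 768)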
@@ -147,13 +158,20 @@ class CLIPTextEncoder(fl.Chain):
         self.use_quick_gelu = use_quick_gelu
         self.tokenizer = tokenizer or CLIPTokenizer()
         super().__init__(
-            PositionalTokenEncoder(
+            fl.Sum(
+                TokenEncoder(
                     vocabulary_size=vocabulary_size,
                     embedding_dim=embedding_dim,
-                    positional_embedding_dim=positional_embedding_dim,
                     device=device,
                     dtype=dtype,
                 ),
+                PositionalEncoder(
+                    max_sequence_length=positional_embedding_dim,
+                    embedding_dim=embedding_dim,
+                    device=device,
+                    dtype=dtype,
+                ),
+            ),
             *(
                 TransformerLayer(
                     embedding_dim=embedding_dim,