diff --git a/src/refiners/foundationals/clip/text_encoder.py b/src/refiners/foundationals/clip/text_encoder.py
index 92038ea..31b224f 100644
--- a/src/refiners/foundationals/clip/text_encoder.py
+++ b/src/refiners/foundationals/clip/text_encoder.py
@@ -3,40 +3,51 @@ import refiners.fluxion.layers as fl
 from refiners.foundationals.clip.tokenizer import CLIPTokenizer
 
 
-class PositionalTokenEncoder(fl.Sum):
-    structural_attrs = ["vocabulary_size", "positional_embedding_dim"]
+class TokenEncoder(fl.Embedding):
+    structural_attrs = ["vocabulary_size", "embedding_dim"]
 
     def __init__(
         self,
         vocabulary_size: int,
         embedding_dim: int,
-        positional_embedding_dim: int,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
         self.vocabulary_size = vocabulary_size
-        self.positional_embedding_dim = positional_embedding_dim
+        self.embedding_dim = embedding_dim
         super().__init__(
+            num_embeddings=vocabulary_size,
+            embedding_dim=embedding_dim,
+            device=device,
+            dtype=dtype,
+        )
+
+
+class PositionalEncoder(fl.Chain):
+    structural_attrs = ["max_sequence_length", "embedding_dim"]
+
+    def __init__(
+        self,
+        max_sequence_length: int,
+        embedding_dim: int,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        self.max_sequence_length = max_sequence_length
+        self.embedding_dim = embedding_dim
+        super().__init__(
+            fl.Lambda(func=self.get_position_ids),
             fl.Embedding(
-                num_embeddings=vocabulary_size,
+                num_embeddings=max_sequence_length,
                 embedding_dim=embedding_dim,
                 device=device,
                 dtype=dtype,
             ),
-            fl.Chain(
-                fl.Lambda(func=self.get_position_ids),
-                fl.Embedding(
-                    num_embeddings=positional_embedding_dim,
-                    embedding_dim=embedding_dim,
-                    device=device,
-                    dtype=dtype,
-                ),
-            ),
         )
 
     @property
     def position_ids(self) -> Tensor:
-        return arange(end=self.positional_embedding_dim, device=self.device).reshape(1, -1)
+        return arange(end=self.max_sequence_length, device=self.device).reshape(1, -1)
 
     def get_position_ids(self, x: Tensor) -> Tensor:
         return self.position_ids[:, : x.shape[1]]
@@ -147,12 +158,19 @@ class CLIPTextEncoder(fl.Chain):
         self.use_quick_gelu = use_quick_gelu
         self.tokenizer = tokenizer or CLIPTokenizer()
         super().__init__(
-            PositionalTokenEncoder(
-                vocabulary_size=vocabulary_size,
-                embedding_dim=embedding_dim,
-                positional_embedding_dim=positional_embedding_dim,
-                device=device,
-                dtype=dtype,
+            fl.Sum(
+                TokenEncoder(
+                    vocabulary_size=vocabulary_size,
+                    embedding_dim=embedding_dim,
+                    device=device,
+                    dtype=dtype,
+                ),
+                PositionalEncoder(
+                    max_sequence_length=positional_embedding_dim,
+                    embedding_dim=embedding_dim,
+                    device=device,
+                    dtype=dtype,
+                ),
             ),
             *(
                 TransformerLayer(
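
After this change, CLIPTextEncoder composes the embedding stage from two independent layers via fl.Sum instead of the monolithic PositionalTokenEncoder. A minimal usage sketch of the refactored classes follows; the hyperparameter values (49408 tokens, embedding dimension 768, 77 positions) are the usual CLIP ViT-L defaults and are assumptions for illustration, not part of this diff:

    import torch

    import refiners.fluxion.layers as fl
    from refiners.foundationals.clip.text_encoder import PositionalEncoder, TokenEncoder

    # fl.Sum feeds the same token ids to both children and adds their outputs,
    # mirroring how CLIPTextEncoder now builds its embedding stage.
    embedder = fl.Sum(
        TokenEncoder(vocabulary_size=49408, embedding_dim=768),
        PositionalEncoder(max_sequence_length=77, embedding_dim=768),
    )

    token_ids = torch.zeros(1, 77, dtype=torch.long)  # dummy, fully padded prompt
    embeddings = embedder(token_ids)                  # token + position embeddings
    print(embeddings.shape)                           # torch.Size([1, 77, 768])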