split PositionalTokenEncoder

limiteinductive 2023-08-16 15:20:21 +02:00 committed by Benjamin Trom
parent 9d663534d1
commit 6fd5894caf


@@ -3,40 +3,51 @@ import refiners.fluxion.layers as fl
 from refiners.foundationals.clip.tokenizer import CLIPTokenizer
 
 
-class PositionalTokenEncoder(fl.Sum):
-    structural_attrs = ["vocabulary_size", "positional_embedding_dim"]
+class TokenEncoder(fl.Embedding):
+    structural_attrs = ["vocabulary_size", "embedding_dim"]
 
     def __init__(
         self,
         vocabulary_size: int,
         embedding_dim: int,
-        positional_embedding_dim: int,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
         self.vocabulary_size = vocabulary_size
-        self.positional_embedding_dim = positional_embedding_dim
+        self.embedding_dim = embedding_dim
+        super().__init__(
+            num_embeddings=vocabulary_size,
+            embedding_dim=embedding_dim,
+            device=device,
+            dtype=dtype,
+        )
+
+
+class PositionalEncoder(fl.Chain):
+    structural_attrs = ["max_sequence_length", "embedding_dim"]
+
+    def __init__(
+        self,
+        max_sequence_length: int,
+        embedding_dim: int,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        self.max_sequence_length = max_sequence_length
+        self.embedding_dim = embedding_dim
         super().__init__(
+            fl.Lambda(func=self.get_position_ids),
             fl.Embedding(
-                num_embeddings=vocabulary_size,
+                num_embeddings=max_sequence_length,
                 embedding_dim=embedding_dim,
                 device=device,
                 dtype=dtype,
             ),
-            fl.Chain(
-                fl.Lambda(func=self.get_position_ids),
-                fl.Embedding(
-                    num_embeddings=positional_embedding_dim,
-                    embedding_dim=embedding_dim,
-                    device=device,
-                    dtype=dtype,
-                ),
-            ),
         )
 
     @property
     def position_ids(self) -> Tensor:
-        return arange(end=self.positional_embedding_dim, device=self.device).reshape(1, -1)
+        return arange(end=self.max_sequence_length, device=self.device).reshape(1, -1)
 
     def get_position_ids(self, x: Tensor) -> Tensor:
         return self.position_ids[:, : x.shape[1]]
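
Note on the new PositionalEncoder: it chains a fl.Lambda that maps the input token tensor to position indices via get_position_ids, then looks those indices up in a fl.Embedding of size max_sequence_length. A minimal sketch of that position-id logic in plain PyTorch (the 77-token limit and the token ids below are illustrative, not taken from this commit):

    from torch import arange, tensor

    max_sequence_length = 77  # illustrative CLIP-style context length
    # Precompute ids 0..max_sequence_length-1 with a leading batch dimension, as position_ids does.
    position_ids = arange(end=max_sequence_length).reshape(1, -1)  # shape (1, 77)

    tokens = tensor([[49406, 320, 1125, 49407]])  # hypothetical (1, 4) batch of token ids
    # get_position_ids truncates the precomputed ids to the input's sequence length:
    print(position_ids[:, : tokens.shape[1]])  # tensor([[0, 1, 2, 3]])
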
@@ -147,12 +158,19 @@ class CLIPTextEncoder(fl.Chain):
         self.use_quick_gelu = use_quick_gelu
         self.tokenizer = tokenizer or CLIPTokenizer()
         super().__init__(
-            PositionalTokenEncoder(
-                vocabulary_size=vocabulary_size,
-                embedding_dim=embedding_dim,
-                positional_embedding_dim=positional_embedding_dim,
-                device=device,
-                dtype=dtype,
+            fl.Sum(
+                TokenEncoder(
+                    vocabulary_size=vocabulary_size,
+                    embedding_dim=embedding_dim,
+                    device=device,
+                    dtype=dtype,
+                ),
+                PositionalEncoder(
+                    max_sequence_length=positional_embedding_dim,
+                    embedding_dim=embedding_dim,
+                    device=device,
+                    dtype=dtype,
+                ),
             ),
             *(
                 TransformerLayer(
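
With this split, CLIPTextEncoder builds its embedding stage as an explicit fl.Sum of the two new layers instead of the former PositionalTokenEncoder. A hedged usage sketch of that composition (it assumes the classes are importable from refiners.foundationals.clip.text_encoder; the 49408/77/768 sizes are typical CLIP values chosen for illustration, not taken from this diff):

    import refiners.fluxion.layers as fl
    from torch import randint

    from refiners.foundationals.clip.text_encoder import PositionalEncoder, TokenEncoder

    # Sum the token-id embedding and the position embedding, as CLIPTextEncoder now does.
    embedder = fl.Sum(
        TokenEncoder(vocabulary_size=49408, embedding_dim=768),
        PositionalEncoder(max_sequence_length=77, embedding_dim=768),
    )

    tokens = randint(low=0, high=49408, size=(1, 77))  # dummy token ids, shape (batch, sequence)
    embeddings = embedder(tokens)  # shape (1, 77, 768): token embeddings + position embeddings
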