update CTOR of CLIPTextEncoder with max_sequence_length

This commit is contained in:
limiteinductive 2023-08-17 10:50:54 +02:00 committed by Benjamin Trom
parent 6fd5894caf
commit 4526d58cd5
2 changed files with 7 additions and 7 deletions

View file

@ -124,7 +124,7 @@ class TransformerLayer(fl.Chain):
class CLIPTextEncoder(fl.Chain):
structural_attrs = [
"embedding_dim",
"positional_embedding_dim",
"max_sequence_length",
"vocabulary_size",
"num_layers",
"num_attention_heads",
@ -137,7 +137,7 @@ class CLIPTextEncoder(fl.Chain):
def __init__(
self,
embedding_dim: int = 768,
positional_embedding_dim: int = 77,
max_sequence_length: int = 77,
vocabulary_size: int = 49408,
num_layers: int = 12,
num_attention_heads: int = 12,
@ -149,7 +149,7 @@ class CLIPTextEncoder(fl.Chain):
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.positional_embedding_dim = positional_embedding_dim
self.max_sequence_length = max_sequence_length
self.vocabulary_size = vocabulary_size
self.num_layers = num_layers
self.num_attention_heads = num_attention_heads
@ -166,7 +166,7 @@ class CLIPTextEncoder(fl.Chain):
dtype=dtype,
),
PositionalEncoder(
max_sequence_length=positional_embedding_dim,
max_sequence_length=max_sequence_length,
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
@ -190,7 +190,7 @@ class CLIPTextEncoder(fl.Chain):
parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())
def encode(self, text: str) -> Tensor:
tokens = self.tokenizer(text, sequence_length=self.positional_embedding_dim).to(device=self.device)
tokens = self.tokenizer(text, sequence_length=self.max_sequence_length).to(device=self.device)
return self(tokens)
@property

View file

@ -63,7 +63,7 @@ def ref_encoder(runwayml_weights_path: Path, test_device: torch.device) -> trans
def test_basics(ref_tokenizer: transformers.CLIPTokenizer, our_encoder: CLIPTextEncoderL):
assert ref_tokenizer.model_max_length == 77 # type: ignore
assert our_encoder.positional_embedding_dim == 77
assert our_encoder.max_sequence_length == 77
@pytest.fixture(params=PROMPTS)
@ -86,7 +86,7 @@ def test_encoder(
return_tensors="pt",
).input_ids
assert isinstance(ref_tokens, torch.Tensor)
our_tokens = our_encoder.tokenizer(prompt, sequence_length=our_encoder.positional_embedding_dim)
our_tokens = our_encoder.tokenizer(prompt, sequence_length=our_encoder.max_sequence_length)
assert torch.equal(our_tokens, ref_tokens)
with torch.no_grad():