mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
update CTOR of CLIPTextEncoder with max_sequence_length
This commit is contained in:
parent
6fd5894caf
commit
4526d58cd5
|
@ -124,7 +124,7 @@ class TransformerLayer(fl.Chain):
|
|||
class CLIPTextEncoder(fl.Chain):
|
||||
structural_attrs = [
|
||||
"embedding_dim",
|
||||
"positional_embedding_dim",
|
||||
"max_sequence_length",
|
||||
"vocabulary_size",
|
||||
"num_layers",
|
||||
"num_attention_heads",
|
||||
|
@ -137,7 +137,7 @@ class CLIPTextEncoder(fl.Chain):
|
|||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 768,
|
||||
positional_embedding_dim: int = 77,
|
||||
max_sequence_length: int = 77,
|
||||
vocabulary_size: int = 49408,
|
||||
num_layers: int = 12,
|
||||
num_attention_heads: int = 12,
|
||||
|
@ -149,7 +149,7 @@ class CLIPTextEncoder(fl.Chain):
|
|||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
self.embedding_dim = embedding_dim
|
||||
self.positional_embedding_dim = positional_embedding_dim
|
||||
self.max_sequence_length = max_sequence_length
|
||||
self.vocabulary_size = vocabulary_size
|
||||
self.num_layers = num_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
@ -166,7 +166,7 @@ class CLIPTextEncoder(fl.Chain):
|
|||
dtype=dtype,
|
||||
),
|
||||
PositionalEncoder(
|
||||
max_sequence_length=positional_embedding_dim,
|
||||
max_sequence_length=max_sequence_length,
|
||||
embedding_dim=embedding_dim,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
|
@ -190,7 +190,7 @@ class CLIPTextEncoder(fl.Chain):
|
|||
parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())
|
||||
|
||||
def encode(self, text: str) -> Tensor:
|
||||
tokens = self.tokenizer(text, sequence_length=self.positional_embedding_dim).to(device=self.device)
|
||||
tokens = self.tokenizer(text, sequence_length=self.max_sequence_length).to(device=self.device)
|
||||
return self(tokens)
|
||||
|
||||
@property
|
||||
|
|
|
@ -63,7 +63,7 @@ def ref_encoder(runwayml_weights_path: Path, test_device: torch.device) -> trans
|
|||
|
||||
def test_basics(ref_tokenizer: transformers.CLIPTokenizer, our_encoder: CLIPTextEncoderL):
|
||||
assert ref_tokenizer.model_max_length == 77 # type: ignore
|
||||
assert our_encoder.positional_embedding_dim == 77
|
||||
assert our_encoder.max_sequence_length == 77
|
||||
|
||||
|
||||
@pytest.fixture(params=PROMPTS)
|
||||
|
@ -86,7 +86,7 @@ def test_encoder(
|
|||
return_tensors="pt",
|
||||
).input_ids
|
||||
assert isinstance(ref_tokens, torch.Tensor)
|
||||
our_tokens = our_encoder.tokenizer(prompt, sequence_length=our_encoder.positional_embedding_dim)
|
||||
our_tokens = our_encoder.tokenizer(prompt, sequence_length=our_encoder.max_sequence_length)
|
||||
assert torch.equal(our_tokens, ref_tokens)
|
||||
|
||||
with torch.no_grad():
|
||||
|
|
Loading…
Reference in a new issue