mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-14 00:58:13 +00:00
Rename the `positional_embedding_dim` constructor argument and attribute of CLIPTextEncoder to `max_sequence_length`
This commit is contained in:
parent
6fd5894caf
commit
4526d58cd5
|
@ -124,7 +124,7 @@ class TransformerLayer(fl.Chain):
|
||||||
class CLIPTextEncoder(fl.Chain):
|
class CLIPTextEncoder(fl.Chain):
|
||||||
structural_attrs = [
|
structural_attrs = [
|
||||||
"embedding_dim",
|
"embedding_dim",
|
||||||
"positional_embedding_dim",
|
"max_sequence_length",
|
||||||
"vocabulary_size",
|
"vocabulary_size",
|
||||||
"num_layers",
|
"num_layers",
|
||||||
"num_attention_heads",
|
"num_attention_heads",
|
||||||
|
@ -137,7 +137,7 @@ class CLIPTextEncoder(fl.Chain):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
embedding_dim: int = 768,
|
embedding_dim: int = 768,
|
||||||
positional_embedding_dim: int = 77,
|
max_sequence_length: int = 77,
|
||||||
vocabulary_size: int = 49408,
|
vocabulary_size: int = 49408,
|
||||||
num_layers: int = 12,
|
num_layers: int = 12,
|
||||||
num_attention_heads: int = 12,
|
num_attention_heads: int = 12,
|
||||||
|
@ -149,7 +149,7 @@ class CLIPTextEncoder(fl.Chain):
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.embedding_dim = embedding_dim
|
self.embedding_dim = embedding_dim
|
||||||
self.positional_embedding_dim = positional_embedding_dim
|
self.max_sequence_length = max_sequence_length
|
||||||
self.vocabulary_size = vocabulary_size
|
self.vocabulary_size = vocabulary_size
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
self.num_attention_heads = num_attention_heads
|
self.num_attention_heads = num_attention_heads
|
||||||
|
@ -166,7 +166,7 @@ class CLIPTextEncoder(fl.Chain):
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
),
|
),
|
||||||
PositionalEncoder(
|
PositionalEncoder(
|
||||||
max_sequence_length=positional_embedding_dim,
|
max_sequence_length=max_sequence_length,
|
||||||
embedding_dim=embedding_dim,
|
embedding_dim=embedding_dim,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
@ -190,7 +190,7 @@ class CLIPTextEncoder(fl.Chain):
|
||||||
parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())
|
parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())
|
||||||
|
|
||||||
def encode(self, text: str) -> Tensor:
|
def encode(self, text: str) -> Tensor:
|
||||||
tokens = self.tokenizer(text, sequence_length=self.positional_embedding_dim).to(device=self.device)
|
tokens = self.tokenizer(text, sequence_length=self.max_sequence_length).to(device=self.device)
|
||||||
return self(tokens)
|
return self(tokens)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -63,7 +63,7 @@ def ref_encoder(runwayml_weights_path: Path, test_device: torch.device) -> trans
|
||||||
|
|
||||||
def test_basics(ref_tokenizer: transformers.CLIPTokenizer, our_encoder: CLIPTextEncoderL):
|
def test_basics(ref_tokenizer: transformers.CLIPTokenizer, our_encoder: CLIPTextEncoderL):
|
||||||
assert ref_tokenizer.model_max_length == 77 # type: ignore
|
assert ref_tokenizer.model_max_length == 77 # type: ignore
|
||||||
assert our_encoder.positional_embedding_dim == 77
|
assert our_encoder.max_sequence_length == 77
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(params=PROMPTS)
|
@pytest.fixture(params=PROMPTS)
|
||||||
|
@ -86,7 +86,7 @@ def test_encoder(
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
).input_ids
|
).input_ids
|
||||||
assert isinstance(ref_tokens, torch.Tensor)
|
assert isinstance(ref_tokens, torch.Tensor)
|
||||||
our_tokens = our_encoder.tokenizer(prompt, sequence_length=our_encoder.positional_embedding_dim)
|
our_tokens = our_encoder.tokenizer(prompt, sequence_length=our_encoder.max_sequence_length)
|
||||||
assert torch.equal(our_tokens, ref_tokens)
|
assert torch.equal(our_tokens, ref_tokens)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|
Loading…
Reference in a new issue