mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
104 lines
3.8 KiB
Python
104 lines
3.8 KiB
Python
import torch
|
|
import pytest
|
|
|
|
from warnings import warn
|
|
from pathlib import Path
|
|
|
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
|
from refiners.fluxion.utils import load_from_safetensors
|
|
|
|
import transformers # type: ignore
|
|
|
|
|
|
# A prompt deliberately longer than CLIP's 77-token context, used to exercise truncation.
long_prompt = """
Above these apparent hieroglyphics was a figure of evidently pictorial intent,
though its impressionistic execution forbade a very clear idea of its nature.
It seemed to be a sort of monster, or symbol representing a monster, of a form
which only a diseased fancy could conceive. If I say that my somewhat extravagant
imagination yielded simultaneous pictures of an octopus, a dragon, and a human
caricature, I shall not be unfaithful to the spirit of the thing. A pulpy,
tentacled head surmounted a grotesque and scaly body with rudimentary wings;
but it was the general outline of the whole which made it most shockingly frightful.
Behind the figure was a vague suggestion of a Cyclopean architectural background.
"""

# Prompts used as params of the `prompt` fixture; the trailing comment names the case each covers.
PROMPTS = [
    "",  # empty
    "a cute cat",  # padded
    long_prompt,  # truncated
    "64k",  # FG-362 - encoded as 3 tokens
]
|
|
|
|
|
|
@pytest.fixture(scope="module")
def our_encoder(test_weights_path: Path, test_device: torch.device) -> CLIPTextEncoderL:
    """Build a refiners ``CLIPTextEncoderL`` on ``test_device`` and load its weights.

    The weights are read from ``CLIPTextEncoderL.safetensors`` under
    ``test_weights_path``. If that file is missing, a warning is emitted and the
    requesting tests are skipped.
    """
    weights = test_weights_path / "CLIPTextEncoderL.safetensors"
    if not weights.is_file():
        message = f"could not find weights at {weights}, skipping"
        warn(message)
        # Give the skip a reason so `pytest -rs` explains it. `allow_module_level`
        # only matters when skipping at import time, not inside a fixture.
        pytest.skip(message)
    encoder = CLIPTextEncoderL(device=test_device)
    tensors = load_from_safetensors(weights)
    encoder.load_state_dict(tensors)
    return encoder
|
|
|
|
|
|
@pytest.fixture(scope="module")
def runwayml_weights_path(test_weights_path: Path) -> Path:
    """Return the local ``runwayml/stable-diffusion-v1-5`` directory under ``test_weights_path``.

    If that directory does not exist, a warning is emitted and the requesting
    tests are skipped.
    """
    checkpoint_dir = test_weights_path / "runwayml" / "stable-diffusion-v1-5"
    if not checkpoint_dir.is_dir():
        message = f"could not find RunwayML weights at {checkpoint_dir}, skipping"
        warn(message)
        # Report the reason in the skip itself; `allow_module_level` has no
        # effect when the skip is raised during fixture setup.
        pytest.skip(message)
    return checkpoint_dir
|
|
|
|
|
|
@pytest.fixture(scope="module")
def ref_tokenizer(runwayml_weights_path: Path) -> transformers.CLIPTokenizer:
    """Load the reference Hugging Face CLIP tokenizer from the checkpoint's ``tokenizer`` subfolder."""
    tokenizer = transformers.CLIPTokenizer.from_pretrained(  # type: ignore
        runwayml_weights_path,
        subfolder="tokenizer",
    )
    return tokenizer  # type: ignore
|
|
|
|
|
|
@pytest.fixture(scope="module")
def ref_encoder(runwayml_weights_path: Path, test_device: torch.device) -> transformers.CLIPTextModel:
    """Load the reference Hugging Face CLIP text model from ``text_encoder`` and move it to ``test_device``."""
    model = transformers.CLIPTextModel.from_pretrained(  # type: ignore
        runwayml_weights_path,
        subfolder="text_encoder",
    )
    return model.to(test_device)  # type: ignore
|
|
|
|
|
|
def test_basics(ref_tokenizer: transformers.CLIPTokenizer, our_encoder: CLIPTextEncoderL):
    """Both implementations must use the same 77-token sequence length."""
    expected_length = 77
    assert ref_tokenizer.model_max_length == expected_length  # type: ignore
    assert our_encoder.max_sequence_length == expected_length
|
|
|
|
|
|
@pytest.fixture(params=PROMPTS)
def prompt(request: pytest.FixtureRequest):
    """Provide one entry of ``PROMPTS`` per test instance (fixture parametrization)."""
    current = request.param
    return current
|
|
|
|
|
|
def test_encoder(
    prompt: str,
    ref_tokenizer: transformers.CLIPTokenizer,
    ref_encoder: transformers.CLIPTextModel,
    our_encoder: CLIPTextEncoderL,
    test_device: torch.device,
):
    """Compare tokens and embeddings of the refiners encoder against the Hugging Face reference.

    Token ids must match exactly. Embeddings must agree within an absolute
    tolerance of 0.01.
    """
    # Reference tokenization, padded/truncated to the tokenizer's own max length.
    expected_tokens = ref_tokenizer(  # type: ignore
        prompt,
        padding="max_length",
        max_length=ref_tokenizer.model_max_length,  # type: ignore
        truncation=True,
        return_tensors="pt",
    ).input_ids
    assert isinstance(expected_tokens, torch.Tensor)

    actual_tokens = our_encoder.tokenizer(prompt, sequence_length=our_encoder.max_sequence_length)
    assert torch.equal(actual_tokens, expected_tokens)

    with torch.no_grad():
        # Element 0 of the HF output is taken as the per-token embeddings
        # (presumably `last_hidden_state`, consistent with the shape check below).
        expected_embeddings = ref_encoder(expected_tokens.to(test_device))[0]
        actual_embeddings = our_encoder(actual_tokens.to(test_device))

    embedding_shape = (1, 77, 768)
    assert expected_embeddings.shape == embedding_shape
    assert actual_embeddings.shape == embedding_shape

    # FG-336 - Not strictly equal because we do not use the same implementation
    # of self-attention. We use `scaled_dot_product_attention` which can have
    # numerical differences depending on the backend.
    # Also we use FP16 weights.
    assert (actual_embeddings - expected_embeddings).abs().max() < 0.01
|