refiners/tests/foundationals/clip/test_concepts.py
Pierre Chapuis 471ef91d1c make __getattr__ on Module return object, not Any
PyTorch chose to make it Any because they expect its users' code
to be "highly dynamic": https://github.com/pytorch/pytorch/pull/104321

It is not the case for us: in Refiners, having untyped code
goes against one of our core principles.

Note that there is currently an open PR in PyTorch to
return `Module | Tensor`, but in practice this is not always
correct either: https://github.com/pytorch/pytorch/pull/115074

I also moved Residuals-related code from SD1 to latent_diffusion
because SDXL should not depend on SD1.
2024-02-06 11:32:18 +01:00

137 lines
5.2 KiB
Python

from pathlib import Path
from warnings import warn
import pytest
import torch
import transformers # type: ignore
from diffusers import StableDiffusionPipeline # type: ignore
import refiners.fluxion.layers as fl
from refiners.fluxion.utils import load_from_safetensors, load_tensors, no_grad
from refiners.foundationals.clip.concepts import ConceptExtender, TokenExtender
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
# Prompts every parametrized encoder test is run against.
PROMPTS = [
    # a simple prompt
    "a cute cat",
    # prompt with two added concepts
    "This artwork is inspired by <gta5-artwork> and uses a <cat-toy> as a prop",
]
@pytest.fixture(scope="module")
def our_encoder_with_new_concepts(
    test_weights_path: Path,
    test_device: torch.device,
    cat_embedding_textual_inversion: torch.Tensor,
    gta5_artwork_embedding_textual_inversion: torch.Tensor,
) -> CLIPTextEncoderL:
    """Build our CLIP text encoder extended with the two textual-inversion concepts.

    The encoder weights are read from ``CLIPTextEncoderL.safetensors`` under
    ``test_weights_path``. When that file is missing, a warning is emitted and
    the dependent tests are skipped with the same message as the skip reason.

    Returns:
        The encoder after ``<cat-toy>`` and ``<gta5-artwork>`` have been added
        through a ``ConceptExtender`` and the extender has been injected.
    """
    weights = test_weights_path / "CLIPTextEncoderL.safetensors"
    if not weights.is_file():
        reason = f"could not find weights at {weights}, skipping"
        warn(reason)
        # Pass the reason along so the skip is self-explanatory in the test report.
        pytest.skip(reason, allow_module_level=True)

    encoder = CLIPTextEncoderL(device=test_device)
    tensors = load_from_safetensors(weights)
    encoder.load_state_dict(tensors)

    concept_extender = ConceptExtender(encoder)
    new_concepts = (
        ("<cat-toy>", cat_embedding_textual_inversion),
        ("<gta5-artwork>", gta5_artwork_embedding_textual_inversion),
    )
    # Concepts are registered in this fixed order before the extender is injected.
    for token, embedding in new_concepts:
        concept_extender.add_concept(token, embedding)
    concept_extender.inject()
    return encoder
@pytest.fixture(scope="module")
def ref_sd15_with_new_concepts(
    runwayml_weights_path: Path,
    test_textual_inversion_path: Path,
    test_device: torch.device,
) -> StableDiffusionPipeline:
    """Load the reference Stable Diffusion pipeline and add both textual inversions.

    The pipeline is created from ``runwayml_weights_path``, moved to
    ``test_device``, and then loads the ``cat-toy`` and ``gta5-artwork``
    textual inversions found under ``test_textual_inversion_path``.
    """
    pipe = StableDiffusionPipeline.from_pretrained(runwayml_weights_path).to(test_device)  # type: ignore
    assert isinstance(pipe, StableDiffusionPipeline)
    # Same loading order as the concepts' directory names listed here.
    for concept_dir in ("cat-toy", "gta5-artwork"):
        pipe.load_textual_inversion(test_textual_inversion_path / concept_dir)  # type: ignore
    return pipe
@pytest.fixture(scope="module")
def runwayml_weights_path(test_weights_path: Path) -> Path:
    """Return the directory holding the RunwayML Stable Diffusion v1.5 weights.

    The directory is ``runwayml/stable-diffusion-v1-5`` under
    ``test_weights_path``. When it does not exist, a warning is emitted and the
    dependent tests are skipped with the same message as the skip reason.
    """
    r = test_weights_path / "runwayml" / "stable-diffusion-v1-5"
    if not r.is_dir():
        reason = f"could not find RunwayML weights at {r}, skipping"
        warn(reason)
        # Pass the reason along so the skip is self-explanatory in the test report.
        pytest.skip(reason, allow_module_level=True)
    return r
@pytest.fixture(scope="module")
def ref_tokenizer_with_new_concepts(
    ref_sd15_with_new_concepts: StableDiffusionPipeline,
) -> transformers.CLIPTokenizer:
    """Expose the tokenizer attribute of the reference pipeline fixture."""
    return ref_sd15_with_new_concepts.tokenizer  # type: ignore
@pytest.fixture(scope="module")
def ref_encoder_with_new_concepts(
    ref_sd15_with_new_concepts: StableDiffusionPipeline,
) -> transformers.CLIPTextModel:
    """Expose the text encoder attribute of the reference pipeline fixture."""
    return ref_sd15_with_new_concepts.text_encoder
@pytest.fixture(params=PROMPTS)
def prompt(request: pytest.FixtureRequest):
    """Yield, one test run at a time, each entry of ``PROMPTS``."""
    current_prompt = request.param
    return current_prompt
@pytest.fixture(scope="module")
def gta5_artwork_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
    """Read the ``<gta5-artwork>`` entry from its ``learned_embeds.bin`` file."""
    embeds_file = test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin"
    learned_embeds = load_tensors(embeds_file)
    return learned_embeds["<gta5-artwork>"]
@pytest.fixture(scope="module")
def cat_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
    """Read the ``<cat-toy>`` entry from its ``learned_embeds.bin`` file."""
    embeds_file = test_textual_inversion_path / "cat-toy" / "learned_embeds.bin"
    learned_embeds = load_tensors(embeds_file)
    return learned_embeds["<cat-toy>"]
def test_tokenizer_with_special_character():
    """Check that a token added through ``TokenExtender`` is used when encoding.

    The special character ``"*"`` is registered under a fresh token id, the
    extender is injected into the chain holding the tokenizer, and encoding
    ``"*"`` must then give: start-of-text id, the new id, end-of-text id.
    """
    clip_tokenizer_chain = fl.Chain(CLIPTokenizer())
    original_clip_tokenizer = clip_tokenizer_chain.layer("CLIPTokenizer", CLIPTokenizer)
    token_extender = TokenExtender(original_clip_tokenizer)

    # Use an id strictly greater than every id currently in the mapping.
    new_token_id = max(original_clip_tokenizer.token_to_id_mapping.values()) + 42
    token_extender.add_token("*", new_token_id)
    token_extender.inject(clip_tokenizer_chain)

    adapted_clip_tokenizer = clip_tokenizer_chain.ensure_find(CLIPTokenizer)
    # Build the expected ids directly as int64 rather than going through the
    # legacy float32 `torch.Tensor(...)` constructor followed by a cast.
    expected_tokens = torch.tensor(
        [
            adapted_clip_tokenizer.start_of_text_token_id,
            new_token_id,
            adapted_clip_tokenizer.end_of_text_token_id,
        ],
        dtype=torch.int64,
    )
    assert torch.allclose(adapted_clip_tokenizer.encode("*"), expected_tokens)
def test_encoder(
    prompt: str,
    ref_tokenizer_with_new_concepts: transformers.CLIPTokenizer,
    ref_encoder_with_new_concepts: transformers.CLIPTextModel,
    our_encoder_with_new_concepts: CLIPTextEncoderL,
    test_device: torch.device,
):
    """Compare our concept-extended encoder against the reference one on a prompt.

    Both the token ids and the resulting embeddings are checked: the tokens
    must be equal, and the embeddings must have shape ``(1, 77, 768)`` and
    differ by less than 0.04 in maximum absolute value.
    """
    # Reference tokenization, padded/truncated to the reference model's max length.
    ref_tokens = ref_tokenizer_with_new_concepts(  # type: ignore
        prompt,
        padding="max_length",
        max_length=ref_tokenizer_with_new_concepts.model_max_length,  # type: ignore
        truncation=True,
        return_tensors="pt",
    ).input_ids
    assert isinstance(ref_tokens, torch.Tensor)

    # Our tokenization, using the tokenizer found inside our encoder.
    our_tokenizer = our_encoder_with_new_concepts.ensure_find(CLIPTokenizer)
    our_tokens = our_tokenizer(prompt)
    assert torch.equal(our_tokens, ref_tokens)

    with no_grad():
        ref_output = ref_encoder_with_new_concepts(ref_tokens.to(test_device))
        ref_embeddings = ref_output[0]
        our_embeddings = our_encoder_with_new_concepts(prompt)

    expected_shape = (1, 77, 768)
    assert ref_embeddings.shape == expected_shape
    assert our_embeddings.shape == expected_shape

    # See `test_encoder` in test_text_encoder.py for details about the tolerance (0.04)
    max_abs_difference = torch.max(torch.abs(our_embeddings - ref_embeddings))
    assert max_abs_difference < 0.04