# refiners/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py
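"""Check that Refiners' SDXL DoubleTextEncoder matches the reference diffusers pipeline.

The same prompt is encoded by both implementations; the token embeddings and the
pooled embeddings are expected to agree within a small tolerance.
"""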

from typing import Any, Protocol, cast
from pathlib import Path
from warnings import warn
import pytest
import torch
from torch import Tensor
from refiners.fluxion.utils import manual_seed
import refiners.fluxion.layers as fl
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextEncoderL
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
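

# Structural type for the diffusers StableDiffusionXLPipeline, declaring only the
# attributes and methods this test relies on.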
class DiffusersSDXL(Protocol):
    unet: fl.Module
    text_encoder: fl.Module
    text_encoder_2: fl.Module
    tokenizer: fl.Module
    tokenizer_2: fl.Module
    vae: fl.Module

    def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any:
        ...

    def encode_prompt(
        self,
        prompt: str,
        prompt_2: str | None = None,
        negative_prompt: str | None = None,
        negative_prompt_2: str | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        ...
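

# Locate the Stability AI SDXL base weights; skip the whole module if they are absent.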
@pytest.fixture(scope="module")
def stabilityai_sdxl_base_path(test_weights_path: Path) -> Path:
    r = test_weights_path / "stabilityai" / "stable-diffusion-xl-base-1.0"
    if not r.is_dir():
        warn(message=f"could not find Stability SDXL base weights at {r}, skipping")
        pytest.skip(allow_module_level=True)
    return r
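

# Reference pipeline: the diffusers SDXL base model loaded from the local weights.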
@pytest.fixture(scope="module")
def diffusers_sdxl(stabilityai_sdxl_base_path: Path) -> Any:
    from diffusers import DiffusionPipeline  # type: ignore

    return DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=stabilityai_sdxl_base_path)  # type: ignore
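

# Refiners' DoubleTextEncoder, assembled from converted CLIP-L and CLIP-G (with projection) weights.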
@pytest.fixture(scope="module")
def double_text_encoder(test_weights_path: Path) -> DoubleTextEncoder:
    text_encoder_l = CLIPTextEncoderL()
    text_encoder_g_with_projection = CLIPTextEncoderG()
    # The converted CLIP-G checkpoint stores the text projection as a final bias-free Linear layer.
    text_encoder_g_with_projection.append(module=fl.Linear(in_features=1280, out_features=1280, bias=False))

    text_encoder_l_path = test_weights_path / "CLIPTextEncoderL.safetensors"
    text_encoder_g_path = test_weights_path / "CLIPTextEncoderGWithProjection.safetensors"

    text_encoder_l.load_from_safetensors(tensors_path=text_encoder_l_path)
    text_encoder_g_with_projection.load_from_safetensors(tensors_path=text_encoder_g_path)

    # Detach the projection layer and hand it to DoubleTextEncoder explicitly.
    linear = text_encoder_g_with_projection.pop(index=-1)
    assert isinstance(linear, fl.Linear)

    double_text_encoder = DoubleTextEncoder(
        text_encoder_l=text_encoder_l, text_encoder_g=text_encoder_g_with_projection, projection=linear
    )
    return double_text_encoder
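

# The two text encoders should produce matching embeddings, both for the prompt and for the
# empty string used as the negative prompt.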
@torch.no_grad()
def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder: DoubleTextEncoder) -> None:
    manual_seed(seed=0)
    prompt = "A photo of a pizza."

    prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = (
        diffusers_sdxl.encode_prompt(prompt=prompt, negative_prompt="")
    )

    double_embedding, pooled_embedding = double_text_encoder(prompt)
    assert double_embedding.shape == torch.Size([1, 77, 2048])
    assert pooled_embedding.shape == torch.Size([1, 1280])

    # The 2048-dim embedding is the concatenation of CLIP-L (768) and CLIP-G (1280) features.
    embedding_1, embedding_2 = cast(
        tuple[Tensor, Tensor], prompt_embeds.split(split_size=[768, 1280], dim=-1)  # type: ignore
    )
    rembedding_1, rembedding_2 = cast(
        tuple[Tensor, Tensor], double_embedding.split(split_size=[768, 1280], dim=-1)  # type: ignore
    )

    assert torch.allclose(input=embedding_1, other=rembedding_1, rtol=1e-3, atol=1e-3)
    assert torch.allclose(input=embedding_2, other=rembedding_2, rtol=1e-3, atol=1e-3)
    assert torch.allclose(input=pooled_embedding, other=pooled_prompt_embeds, rtol=1e-3, atol=1e-3)

    # The empty prompt must also reproduce the diffusers negative-prompt embeddings.
    negative_double_embedding, negative_pooled_embedding = double_text_encoder("")
    assert torch.allclose(input=negative_double_embedding, other=negative_prompt_embeds, rtol=1e-3, atol=1e-3)
    assert torch.allclose(input=negative_pooled_embedding, other=negative_pooled_prompt_embeds, rtol=1e-3, atol=1e-3)