mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 22:28:46 +00:00
111 lines
4.2 KiB
Python
from typing import Any, Protocol, cast
|
|
from pathlib import Path
|
|
from warnings import warn
|
|
import pytest
|
|
import torch
|
|
from torch import Tensor
|
|
|
|
from refiners.fluxion.utils import manual_seed
|
|
import refiners.fluxion.layers as fl
|
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextEncoderL
|
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
|
|
|
|
|
class DiffusersSDXL(Protocol):
    """Structural type for the parts of a diffusers SDXL pipeline that this module touches.

    The object actually supplied is whatever `DiffusionPipeline.from_pretrained` returns;
    this protocol only exists so type checkers can follow attribute and method access.
    """

    # Pipeline sub-components. NOTE(review): annotated as `fl.Module` for convenience only;
    # the real diffusers objects (notably the tokenizers) are presumably not fluxion
    # modules, so confirm before relying on these annotations.
    unet: fl.Module
    text_encoder: fl.Module
    text_encoder_2: fl.Module
    tokenizer: fl.Module
    tokenizer_2: fl.Module
    vae: fl.Module

    def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any:
        """Invoke the pipeline on `prompt` (not exercised in this file)."""

    def encode_prompt(
        self,
        prompt: str,
        prompt_2: str | None = None,
        negative_prompt: str | None = None,
        negative_prompt_2: str | None = None,
    ) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Encode the positive and negative prompts.

        Returns four tensors, unpacked in this file in this order: prompt embeddings,
        negative prompt embeddings, pooled prompt embeddings, negative pooled prompt
        embeddings.
        """
|
|
|
|
|
|
@pytest.fixture(scope="module")
def stabilityai_sdxl_base_path(test_weights_path: Path) -> Path:
    """Return the directory holding the Stability AI SDXL base 1.0 weights.

    `test_weights_path` is a fixture provided outside this file. If the expected
    sub-directory is missing, a warning is emitted and the tests depending on this
    fixture are skipped.
    """
    r = test_weights_path / "stabilityai" / "stable-diffusion-xl-base-1.0"
    if not r.is_dir():
        message = f"could not find Stability SDXL base weights at {r}, skipping"
        warn(message=message)
        # Also use the message as the skip reason so it appears in pytest's skip
        # summary (`-rs`), not only in the warnings section. Passed positionally to
        # stay compatible with both the `msg` (old) and `reason` (new) parameter names.
        pytest.skip(message, allow_module_level=True)
    return r
|
|
|
|
|
|
@pytest.fixture(scope="module")
def diffusers_sdxl(stabilityai_sdxl_base_path: Path) -> Any:
    """Load the reference diffusers SDXL pipeline once for the whole module.

    `diffusers` is imported inside the fixture, so it is only needed when a test
    actually requests this fixture. The weights-path fixture runs first, so missing
    weights lead to a skip before this import is attempted.
    """
    from diffusers import DiffusionPipeline  # type: ignore

    pipeline = DiffusionPipeline.from_pretrained(  # type: ignore
        pretrained_model_name_or_path=stabilityai_sdxl_base_path
    )
    return pipeline
|
|
|
|
|
|
@pytest.fixture(scope="module")
def double_text_encoder(test_weights_path: Path) -> DoubleTextEncoder:
    """Build a refiners `DoubleTextEncoder` from converted CLIP-L and CLIP-G weights.

    The CLIP-G weights file ("...GWithProjection") is loaded into a `CLIPTextEncoderG`
    with an extra trailing 1280 -> 1280 bias-free linear layer. That layer is then
    popped off and passed to `DoubleTextEncoder` as its separate `projection`.

    If either weights file is missing, a warning is emitted and dependent tests are
    skipped.
    """
    text_encoder_l_path = test_weights_path / "CLIPTextEncoderL.safetensors"
    text_encoder_g_path = test_weights_path / "CLIPTextEncoderGWithProjection.safetensors"

    # Check for the weights before instantiating any encoder, so a missing file skips
    # right away instead of paying for model construction first. Order (L, then G)
    # and messages are unchanged.
    for weights_path in (text_encoder_l_path, text_encoder_g_path):
        if not weights_path.is_file():
            message = f"could not find weights at {weights_path}, skipping"
            warn(message)
            pytest.skip(message, allow_module_level=True)

    text_encoder_l = CLIPTextEncoderL()
    text_encoder_l.load_from_safetensors(tensors_path=text_encoder_l_path)

    # Append the projection layer before loading. Presumably this makes the module
    # structure match the "WithProjection" checkpoint (TODO confirm against the
    # conversion script).
    text_encoder_g_with_projection = CLIPTextEncoderG()
    text_encoder_g_with_projection.append(module=fl.Linear(in_features=1280, out_features=1280, bias=False))
    text_encoder_g_with_projection.load_from_safetensors(tensors_path=text_encoder_g_path)

    # Detach the projection again: DoubleTextEncoder takes it as its own argument.
    projection = text_encoder_g_with_projection.pop(index=-1)
    assert isinstance(projection, fl.Linear)

    return DoubleTextEncoder(
        text_encoder_l=text_encoder_l,
        text_encoder_g=text_encoder_g_with_projection,
        projection=projection,
    )
|
|
|
|
|
|
@torch.no_grad()
def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder: DoubleTextEncoder) -> None:
    """Check refiners' DoubleTextEncoder against diffusers' SDXL prompt encoding.

    For the prompt and for the empty negative prompt, the following must agree within
    rtol = atol = 1e-3:
    - the per-token embeddings, compared separately for the 768-wide (CLIP-L) and
      1280-wide (CLIP-G) halves on the positive side;
    - the pooled embeddings.
    """
    manual_seed(seed=0)
    prompt = "A photo of a pizza."
    # Last-dim split of the 2048-wide concatenated embedding into its two encoder parts.
    widths = [768, 1280]

    def is_close(lhs: Tensor, rhs: Tensor) -> bool:
        # allclose is asymmetric (tolerance scales with |other|); callers keep the
        # argument order as input/other explicitly.
        return torch.allclose(input=lhs, other=rhs, rtol=1e-3, atol=1e-3)

    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = diffusers_sdxl.encode_prompt(prompt=prompt, negative_prompt="")

    double_embedding, pooled_embedding = double_text_encoder(prompt)
    assert double_embedding.shape == torch.Size([1, 77, 2048])
    assert pooled_embedding.shape == torch.Size([1, 1280])

    reference_l, reference_g = cast(
        tuple[Tensor, Tensor], prompt_embeds.split(split_size=widths, dim=-1)  # type: ignore
    )
    refiners_l, refiners_g = cast(
        tuple[Tensor, Tensor], double_embedding.split(split_size=widths, dim=-1)  # type: ignore
    )

    assert is_close(reference_l, refiners_l)
    assert is_close(reference_g, refiners_g)
    assert is_close(pooled_embedding, pooled_prompt_embeds)

    negative_double_embedding, negative_pooled_embedding = double_text_encoder("")
    assert is_close(negative_double_embedding, negative_prompt_embeds)
    assert is_close(negative_pooled_embedding, negative_pooled_prompt_embeds)
|