mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-15 09:38:14 +00:00
71 lines
3.5 KiB
Python
71 lines
3.5 KiB
Python
|
import torch
|
||
|
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
|
||
|
from torch import Tensor
|
||
|
|
||
|
from refiners.fluxion.utils import manual_seed, no_grad
|
||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
||
|
|
||
|
|
||
|
@no_grad()
def test_double_text_encoder(
    diffusers_sdxl_pipeline: StableDiffusionXLPipeline,
    refiners_sdxl_text_encoder: DoubleTextEncoder,
) -> None:
    """Compare our refiners implementation with the diffusers implementation.

    The same prompt and negative prompt are encoded once with the diffusers
    pipeline and once with the refiners text encoder. The embeddings and the
    pooled embeddings must then agree both in shape and in value (within
    `atol`).
    """
    manual_seed(seed=0)  # unnecessary, but just in case
    prompt = "A photo of a pizza."
    negative_prompt = ""
    atol = 1e-6

    (  # encode text prompts using diffusers pipeline
        diffusers_embeds,
        diffusers_negative_embeds,  # type: ignore
        diffusers_pooled_embeds,  # type: ignore
        diffusers_negative_pooled_embeds,  # type: ignore
    ) = diffusers_sdxl_pipeline.encode_prompt(prompt=prompt, negative_prompt=negative_prompt)
    assert diffusers_negative_embeds is not None
    assert isinstance(diffusers_pooled_embeds, Tensor)
    assert isinstance(diffusers_negative_pooled_embeds, Tensor)

    # encode text prompts using refiners model
    refiners_embeds, refiners_pooled_embeds = refiners_sdxl_text_encoder(prompt)
    # feed the very same negative prompt that diffusers received, so that both
    # encoders keep being compared on identical inputs if the prompt changes
    refiners_negative_embeds, refiners_negative_pooled_embeds = refiners_sdxl_text_encoder(negative_prompt)

    # check that the shapes are the same
    assert diffusers_embeds.shape == refiners_embeds.shape == (1, 77, 2048)
    assert diffusers_pooled_embeds.shape == refiners_pooled_embeds.shape == (1, 1280)
    assert diffusers_negative_embeds.shape == refiners_negative_embeds.shape == (1, 77, 2048)
    assert diffusers_negative_pooled_embeds.shape == refiners_negative_pooled_embeds.shape == (1, 1280)

    # check that the values are close
    assert torch.allclose(input=refiners_embeds, other=diffusers_embeds, atol=atol)
    assert torch.allclose(input=refiners_negative_embeds, other=diffusers_negative_embeds, atol=atol)
    assert torch.allclose(input=refiners_negative_pooled_embeds, other=diffusers_negative_pooled_embeds, atol=atol)
    assert torch.allclose(input=refiners_pooled_embeds, other=diffusers_pooled_embeds, atol=atol)
|
||
|
|
||
|
|
||
|
@no_grad()
def test_double_text_encoder_batched(refiners_sdxl_text_encoder: DoubleTextEncoder) -> None:
    """Verify that two prompts encoded as one batch match the same prompts encoded separately."""
    manual_seed(seed=0)  # should not be needed; kept as a precaution
    prompts = ["A photo of a pizza.", "A giant duck."]
    tolerance = 1e-6

    # one call with both prompts
    batch_embeds, batch_pooled = refiners_sdxl_text_encoder(prompts)
    assert batch_embeds.shape == (2, 77, 2048)
    assert batch_pooled.shape == (2, 1280)

    # one call per prompt
    first_embeds, first_pooled = refiners_sdxl_text_encoder(prompts[0])
    second_embeds, second_pooled = refiners_sdxl_text_encoder(prompts[1])
    assert first_embeds.shape == second_embeds.shape == (1, 77, 2048)
    assert first_pooled.shape == second_pooled.shape == (1, 1280)

    # each row of the batched output must be close to its standalone encoding;
    # slicing with `row : row + 1` keeps the leading batch dimension of size 1
    standalone = ((first_embeds, first_pooled), (second_embeds, second_pooled))
    for row, (single_embeds, single_pooled) in enumerate(standalone):
        assert torch.allclose(input=single_embeds, other=batch_embeds[row : row + 1], atol=tolerance)
        assert torch.allclose(input=single_pooled, other=batch_pooled[row : row + 1], atol=tolerance)
|