refiners/tests/foundationals/latent_diffusion/test_sdxl_unet.py

70 lines
2.5 KiB
Python
Raw Normal View History

2023-08-04 13:28:41 +00:00
from typing import Any
from pathlib import Path
from warnings import warn
import pytest
import torch
from refiners.foundationals.latent_diffusion.sdxl_unet import SDXLUNet
from refiners.fluxion.utils import compare_models
@pytest.fixture(scope="module")
def stabilityai_sdxl_base_path(test_weights_path: Path) -> Path:
r = test_weights_path / "stabilityai" / "stable-diffusion-xl-base-0.9"
if not r.is_dir():
warn(f"could not find Stability SDXL base weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture(scope="module")
def diffusers_sdxl(stabilityai_sdxl_base_path: Path) -> Any:
from diffusers import DiffusionPipeline # type: ignore
return DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=stabilityai_sdxl_base_path) # type: ignore
@pytest.fixture(scope="module")
def diffusers_sdxl_unet(diffusers_sdxl: Any) -> Any:
return diffusers_sdxl.unet
@pytest.fixture(scope="module")
def sdxl_unet_weights_std(test_weights_path: Path) -> Path:
unet_weights_std = test_weights_path / "sdxl-unet.safetensors"
if not unet_weights_std.is_file():
warn(message=f"could not find weights at {unet_weights_std}, skipping")
pytest.skip(allow_module_level=True)
return unet_weights_std
@pytest.fixture(scope="module")
def refiners_sdxl_unet(sdxl_unet_weights_std: Path) -> SDXLUNet:
unet = SDXLUNet(in_channels=4)
unet.load_from_safetensors(tensors_path=sdxl_unet_weights_std)
return unet
@torch.no_grad()
def test_sdxl_unet(diffusers_sdxl_unet: Any, refiners_sdxl_unet: SDXLUNet) -> None:
torch.manual_seed(seed=0) # type: ignore
x = torch.randn(1, 4, 32, 32)
timestep = torch.tensor(data=[0])
clip_text_embeddings = torch.randn(1, 77, 2048)
added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}
source_args = (x, timestep, clip_text_embeddings, None, None, None, None, added_cond_kwargs)
refiners_sdxl_unet.set_timestep(timestep=timestep)
refiners_sdxl_unet.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
refiners_sdxl_unet.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
refiners_sdxl_unet.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
target_args = (x,)
assert compare_models(
source_model=diffusers_sdxl_unet,
target_model=refiners_sdxl_unet,
source_args=source_args,
target_args=target_args,
threshold=1e-2,
)