write StyleAligned e2e test

This commit is contained in:
Laurent 2024-02-15 14:11:11 +00:00 committed by Laureηt
parent 60c0780fe7
commit da3c3602fb
3 changed files with 83 additions and 0 deletions

View file

@ -30,6 +30,7 @@ from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSc
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.control_lora import ControlLoraAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAdapter
from tests.utils import ensure_similar_images
@ -150,6 +151,11 @@ def expected_sdxl_euler_random_init(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_cutecat_sdxl_euler_random_init.png").convert("RGB")
@pytest.fixture
def expected_style_aligned(ref_path: Path) -> Image.Image:
    """Return the StyleAligned reference image, loaded from `ref_path` and converted to RGB."""
    reference_file = ref_path / "expected_style_aligned.png"
    reference_image = Image.open(fp=reference_file)
    return reference_image.convert(mode="RGB")
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
def controlnet_data(
ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
@ -2140,3 +2146,79 @@ def test_hello_world(
predicted_image = sdxl.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image)
@no_grad()
def test_style_aligned(
    sdxl_ddim_lda_fp16_fix: StableDiffusion_XL,
    expected_style_aligned: Image.Image,
):
    """End-to-end check of StyleAligned generation with SDXL.

    Injects a `StyleAlignedAdapter` into the SDXL UNet, denoises one latent per
    prompt in a single batch, decodes each latent, tiles the resulting images
    horizontally and compares the tiled image against the
    `expected_style_aligned` reference.
    """
    sdxl = sdxl_ddim_lda_fp16_fix.to(dtype=torch.float16)
    sdxl.dtype = torch.float16  # FIXME: should not be necessary

    style_aligned_adapter = StyleAlignedAdapter(sdxl.unet)
    style_aligned_adapter.inject()

    set_of_prompts = [
        "a toy train. macro photo. 3d game asset",
        "a toy airplane. macro photo. 3d game asset",
        "a toy bicycle. macro photo. 3d game asset",
        "a toy car. macro photo. 3d game asset",
        "a toy boat. macro photo. 3d game asset",
    ]
    batch_size = len(set_of_prompts)

    # create (context) embeddings from prompts
    # TODO: replace this logic with https://github.com/finegrain-ai/refiners/pull/263 when it gets merged
    unconds: list[torch.Tensor] = []
    conds: list[torch.Tensor] = []
    pooled_unconds: list[torch.Tensor] = []
    pooled_conds: list[torch.Tensor] = []
    for prompt in set_of_prompts:
        clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(text=prompt)
        # each per-prompt embedding is split in two halves: (uncond, cond)
        uncond, cond = clip_text_embedding.chunk(2)
        pooled_uncond, pooled_cond = pooled_text_embedding.chunk(2)
        unconds.append(uncond)
        conds.append(cond)
        pooled_unconds.append(pooled_uncond)
        pooled_conds.append(pooled_cond)

    # batch layout: every unconditional row first, then every conditional row
    clip_text_embedding = torch.cat((*unconds, *conds), dim=0)
    pooled_text_embedding = torch.cat((*pooled_unconds, *pooled_conds), dim=0)
    time_ids = sdxl.default_time_ids.repeat(batch_size, 1)

    # initialize latents
    latent_size = 128
    manual_seed(seed=2)
    x = torch.randn(
        (batch_size, 4, latent_size, latent_size),
        device=sdxl.device,
        dtype=sdxl.dtype,
    )

    # denoise
    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
        )

    # decode latents one at a time; iterating `x` drops the batch dimension,
    # so it is restored with `unsqueeze(0)` before decoding.
    # `latents_to_image` is the decoding call used by the other tests in this file.
    predicted_images = [sdxl.lda.latents_to_image(latent.unsqueeze(0)) for latent in x]

    # tile all images horizontally
    tile_size = 1024  # edge length, in pixels, of each slot in the tiled image
    merged_image = Image.new("RGB", (tile_size * len(predicted_images), tile_size))
    for index, predicted_image in enumerate(predicted_images):
        merged_image.paste(predicted_image, (index * tile_size, 0))

    # compare against reference image
    ensure_similar_images(merged_image, expected_style_aligned, min_psnr=35, min_ssim=0.99)

View file

@ -56,6 +56,7 @@ Special cases:
- `expected_controllora_PyraCanny.png`
- `expected_controllora_PyraCanny+CPDS.png`
- `expected_controllora_disabled.png`
- `expected_style_aligned.png`
## Other images

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.4 MiB