diff --git a/src/refiners/foundationals/latent_diffusion/__init__.py b/src/refiners/foundationals/latent_diffusion/__init__.py index b301aac..42bdaf8 100644 --- a/src/refiners/foundationals/latent_diffusion/__init__.py +++ b/src/refiners/foundationals/latent_diffusion/__init__.py @@ -4,6 +4,7 @@ from refiners.foundationals.latent_diffusion.auto_encoder import ( from refiners.foundationals.clip.text_encoder import ( CLIPTextEncoderL, ) +from refiners.foundationals.latent_diffusion.freeu import SDFreeUAdapter from refiners.foundationals.latent_diffusion.schedulers import Scheduler, DPMSolver from refiners.foundationals.latent_diffusion.stable_diffusion_1 import ( StableDiffusion_1, @@ -36,4 +37,5 @@ __all__ = [ "Scheduler", "CLIPTextEncoderL", "LatentDiffusionAutoencoder", + "SDFreeUAdapter", ] diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index fca8dfa..3ef40d3 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -17,6 +17,7 @@ from refiners.foundationals.latent_diffusion import ( SD1T2IAdapter, SDXLIPAdapter, SDXLT2IAdapter, + SDFreeUAdapter, ) from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget @@ -227,6 +228,11 @@ def expected_restart(ref_path: Path) -> Image.Image: return Image.open(fp=ref_path / "expected_restart.png").convert(mode="RGB") +@pytest.fixture +def expected_freeu(ref_path: Path) -> Image.Image: + return Image.open(fp=ref_path / "expected_freeu.png").convert(mode="RGB") + + @pytest.fixture def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor: return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")[""] # type: ignore @@ -1604,3 +1610,37 @@ def test_restart( predicted_image = sd15.lda.decode_latents(x) ensure_similar_images(predicted_image, expected_restart, min_psnr=35, min_ssim=0.98) + + +@torch.no_grad() +def test_freeu( + sd15_std: StableDiffusion_1, + expected_freeu: Image.Image, +): + sd15 = sd15_std + n_steps = 50 + first_step = 1 + + prompt = "best quality, high quality cute cat" + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) + + sd15.set_num_inference_steps(n_steps) + + SDFreeUAdapter( + sd15.unet, backbone_scales=[1.2, 1.2, 1.2, 1.4, 1.4, 1.4], skip_scales=[0.9, 0.9, 0.9, 0.2, 0.2, 0.2] + ).inject() + + manual_seed(9752) + x = sd15.init_latents(size=(512, 512), first_step=first_step).to(device=sd15.device, dtype=sd15.dtype) + + for step in sd15.steps[first_step:]: + x = sd15( + x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=7.5, + ) + predicted_image = sd15.lda.decode_latents(x) + + ensure_similar_images(predicted_image, expected_freeu) diff --git a/tests/e2e/test_diffusion_ref/README.md b/tests/e2e/test_diffusion_ref/README.md index 62cf504..2402be5 100644 --- a/tests/e2e/test_diffusion_ref/README.md +++ b/tests/e2e/test_diffusion_ref/README.md @@ -45,6 +45,7 @@ Special cases: - `expected_image_sdxl_ip_adapter_plus_woman.png` - `expected_cutecat_sdxl_ddim_random_init_sag.png` - `expected_restart.png` + - `expected_freeu.png` ## Other images diff --git a/tests/e2e/test_diffusion_ref/expected_freeu.png b/tests/e2e/test_diffusion_ref/expected_freeu.png new file mode 100644 index 0000000..01efceb Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_freeu.png differ diff --git a/tests/foundationals/latent_diffusion/test_freeu.py b/tests/foundationals/latent_diffusion/test_freeu.py new file mode 100644 index 0000000..786cfe3 --- /dev/null +++ b/tests/foundationals/latent_diffusion/test_freeu.py @@ -0,0 +1,41 @@ +from typing import Iterator + +import pytest + +from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet +from refiners.foundationals.latent_diffusion.freeu import SDFreeUAdapter, FreeUResidualConcatenator + + +@pytest.fixture(scope="module", params=[True, False]) +def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet | SDXLUNet]: + xl: bool = request.param + unet = SDXLUNet(in_channels=4) if xl else SD1UNet(in_channels=4) + yield unet + + +def test_freeu_adapter(unet: SD1UNet | SDXLUNet) -> None: + freeu = SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2], skip_scales=[0.9, 0.9]) + + assert len(list(unet.walk(FreeUResidualConcatenator))) == 0 + + with pytest.raises(AssertionError) as exc: + freeu.eject() + assert "could not find" in str(exc.value) + + freeu.inject() + assert len(list(unet.walk(FreeUResidualConcatenator))) == 2 + + freeu.eject() + assert len(list(unet.walk(FreeUResidualConcatenator))) == 0 + + +def test_freeu_adapter_too_many_scales(unet: SD1UNet | SDXLUNet) -> None: + num_blocks = len(unet.UpBlocks) + + with pytest.raises(AssertionError): + SDFreeUAdapter(unet, backbone_scales=[1.2] * (num_blocks + 1), skip_scales=[0.9] * (num_blocks + 1)) + + +def test_freeu_adapter_inconsistent_scales(unet: SD1UNet | SDXLUNet) -> None: + with pytest.raises(AssertionError): + SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2], skip_scales=[0.9, 0.9, 0.9])