Add an end-to-end test for DDIM sampling with the Karras noise schedule.

Adds to tests/e2e/test_diffusion.py: the NoiseSchedule import, an
expected_karras_random_init reference-image fixture, an sd15_ddim_karras
fixture (DDIM, 20 inference steps, NoiseSchedule.KARRAS; skipped when the
test device is CPU) and test_diffusion_karras_random_init. Also adds the
binary reference image expected_karras_random_init.png.

diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py
index 3ef40d3..dcbf30e 100644
--- a/tests/e2e/test_diffusion.py
+++ b/tests/e2e/test_diffusion.py
@@ -25,6 +25,7 @@ from refiners.foundationals.latent_diffusion.restart import Restart
 from refiners.foundationals.latent_diffusion.schedulers import DDIM
 from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
 from refiners.foundationals.clip.concepts import ConceptExtender
+from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
 
@@ -66,6 +67,11 @@ def expected_image_std_random_init(ref_path: Path) -> Image.Image:
     return Image.open(ref_path / "expected_std_random_init.png").convert("RGB")
 
 
+@pytest.fixture
+def expected_karras_random_init(ref_path: Path) -> Image.Image:
+    return Image.open(ref_path / "expected_karras_random_init.png").convert("RGB")
+
+
 @pytest.fixture
 def expected_image_std_random_init_sag(ref_path: Path) -> Image.Image:
     return Image.open(ref_path / "expected_std_random_init_sag.png").convert("RGB")
@@ -416,6 +422,24 @@ def sd15_ddim(
     return sd15
 
 
+@pytest.fixture
+def sd15_ddim_karras(
+    text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
+) -> StableDiffusion_1:
+    if test_device.type == "cpu":
+        warn("not running on CPU, skipping")
+        pytest.skip()
+
+    ddim_scheduler = DDIM(num_inference_steps=20, noise_schedule=NoiseSchedule.KARRAS)
+    sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device)
+
+    sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
+    sd15.lda.load_from_safetensors(lda_weights)
+    sd15.unet.load_from_safetensors(unet_weights_std)
+
+    return sd15
+
+
 @pytest.fixture
 def sd15_ddim_lda_ft_mse(
     text_encoder_weights: Path, lda_ft_mse_weights: Path, unet_weights_std: Path, test_device: torch.device
@@ -507,6 +531,31 @@ def test_diffusion_std_random_init(
     ensure_similar_images(predicted_image, expected_image_std_random_init)
 
 
+@torch.no_grad()
+def test_diffusion_karras_random_init(
+    sd15_ddim_karras: StableDiffusion_1, expected_karras_random_init: Image.Image, test_device: torch.device
+):
+    sd15 = sd15_ddim_karras
+
+    prompt = "a cute cat, detailed high-quality professional image"
+    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
+    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+
+    manual_seed(2)
+    x = torch.randn(1, 4, 64, 64, device=test_device)
+
+    for step in sd15.steps:
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+    predicted_image = sd15.lda.decode_latents(x)
+
+    ensure_similar_images(predicted_image, expected_karras_random_init, min_psnr=35, min_ssim=0.98)
+
+
 @torch.no_grad()
 def test_diffusion_std_random_init_float16(
     sd15_std_float16: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
diff --git a/tests/e2e/test_diffusion_ref/expected_karras_random_init.png b/tests/e2e/test_diffusion_ref/expected_karras_random_init.png
new file mode 100644
index 0000000..75a4d76
Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_karras_random_init.png differ