diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py
index 4850e79..4ced66c 100644
--- a/tests/e2e/test_diffusion.py
+++ b/tests/e2e/test_diffusion.py
@@ -23,7 +23,7 @@ from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
 from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
 from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
 from refiners.foundationals.latent_diffusion.restart import Restart
-from refiners.foundationals.latent_diffusion.schedulers import DDIM
+from refiners.foundationals.latent_diffusion.schedulers import DDIM, EulerScheduler
 from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
@@ -65,6 +65,11 @@ def expected_image_std_random_init(ref_path: Path) -> Image.Image:
     return Image.open(ref_path / "expected_std_random_init.png").convert("RGB")
 
 
+@pytest.fixture
+def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
+    return Image.open(ref_path / "expected_std_random_init_euler.png").convert("RGB")  # reference image for the Euler test, as RGB
+
+
 @pytest.fixture
 def expected_karras_random_init(ref_path: Path) -> Image.Image:
     return Image.open(ref_path / "expected_karras_random_init.png").convert("RGB")
@@ -438,6 +443,24 @@ def sd15_ddim_karras(
     return sd15
 
 
+@pytest.fixture
+def sd15_euler(
+    text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
+) -> StableDiffusion_1:
+    if test_device.type == "cpu":
+        warn("not running on CPU, skipping")
+        pytest.skip()  # skip whenever the test device is a CPU
+
+    euler_scheduler = EulerScheduler(num_inference_steps=30)
+    sd15 = StableDiffusion_1(scheduler=euler_scheduler, device=test_device)
+
+    sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
+    sd15.lda.load_from_safetensors(lda_weights)
+    sd15.unet.load_from_safetensors(unet_weights_std)
+
+    return sd15
+
+
 @pytest.fixture
 def sd15_ddim_lda_ft_mse(
     text_encoder_weights: Path, lda_ft_mse_weights: Path, unet_weights_std: Path, test_device: torch.device
@@ -529,6 +552,37 @@ def test_diffusion_std_random_init(
     ensure_similar_images(predicted_image, expected_image_std_random_init)
 
 
+@no_grad()
+def test_diffusion_std_random_init_euler(
+    sd15_euler: StableDiffusion_1, expected_image_std_random_init_euler: Image.Image, test_device: torch.device
+):
+    sd15 = sd15_euler
+    euler_scheduler = sd15_euler.scheduler
+    assert isinstance(euler_scheduler, EulerScheduler)  # narrow the type before reading init_noise_sigma below
+    n_steps = 30  # same value as num_inference_steps given to EulerScheduler in the sd15_euler fixture
+
+    prompt = "a cute cat, detailed high-quality professional image"
+    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
+    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+
+    sd15.set_num_inference_steps(n_steps)
+
+    manual_seed(2)  # presumably seeds the RNG so the randn latents are reproducible - TODO confirm
+    x = torch.randn(1, 4, 64, 64, device=test_device)
+    x = x * euler_scheduler.init_noise_sigma  # scale the initial latents by the scheduler's init_noise_sigma
+
+    for step in sd15.steps:
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+    predicted_image = sd15.lda.decode_latents(x)  # decode the latents produced by the last step
+
+    ensure_similar_images(predicted_image, expected_image_std_random_init_euler)
+
+
 @no_grad()
 def test_diffusion_karras_random_init(
     sd15_ddim_karras: StableDiffusion_1, expected_karras_random_init: Image.Image, test_device: torch.device
diff --git a/tests/e2e/test_diffusion_ref/expected_std_random_init_euler.png b/tests/e2e/test_diffusion_ref/expected_std_random_init_euler.png
new file mode 100644
index 0000000..b79390c
Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_std_random_init_euler.png differ