From ce3035923ba71bcb5044708d2f1c37fd1d6722e9 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Thu, 18 Jan 2024 14:30:13 +0100 Subject: [PATCH] improve DPM solver test --- .../latent_diffusion/test_schedulers.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/foundationals/latent_diffusion/test_schedulers.py b/tests/foundationals/latent_diffusion/test_schedulers.py index a6a524e..91956ca 100644 --- a/tests/foundationals/latent_diffusion/test_schedulers.py +++ b/tests/foundationals/latent_diffusion/test_schedulers.py @@ -18,14 +18,21 @@ def test_ddpm_diffusers(): assert equal(diffusers_scheduler.timesteps, refiners_scheduler.timesteps) -def test_dpm_solver_diffusers(): +@pytest.mark.parametrize("n_steps, last_step_first_order", [(5, False), (5, True), (30, False), (30, True)]) +def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool): from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler # type: ignore manual_seed(0) - diffusers_scheduler = DiffuserScheduler(beta_schedule="scaled_linear", beta_start=0.00085, beta_end=0.012) - diffusers_scheduler.set_timesteps(30) - refiners_scheduler = DPMSolver(num_inference_steps=30) + diffusers_scheduler = DiffuserScheduler( + beta_schedule="scaled_linear", + beta_start=0.00085, + beta_end=0.012, + lower_order_final=False, + euler_at_final=last_step_first_order, + ) + diffusers_scheduler.set_timesteps(n_steps) + refiners_scheduler = DPMSolver(num_inference_steps=n_steps, last_step_first_order=last_step_first_order) sample = randn(1, 3, 32, 32) noise = randn(1, 3, 32, 32)