improve DPM solver test

This commit is contained in:
Pierre Chapuis 2024-01-18 14:30:13 +01:00
parent 999e429697
commit ce3035923b

View file

@ -18,14 +18,21 @@ def test_ddpm_diffusers():
assert equal(diffusers_scheduler.timesteps, refiners_scheduler.timesteps) assert equal(diffusers_scheduler.timesteps, refiners_scheduler.timesteps)
def test_dpm_solver_diffusers(): @pytest.mark.parametrize("n_steps, last_step_first_order", [(5, False), (5, True), (30, False), (30, True)])
def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler # type: ignore from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler # type: ignore
manual_seed(0) manual_seed(0)
diffusers_scheduler = DiffuserScheduler(beta_schedule="scaled_linear", beta_start=0.00085, beta_end=0.012) diffusers_scheduler = DiffuserScheduler(
diffusers_scheduler.set_timesteps(30) beta_schedule="scaled_linear",
refiners_scheduler = DPMSolver(num_inference_steps=30) beta_start=0.00085,
beta_end=0.012,
lower_order_final=False,
euler_at_final=last_step_first_order,
)
diffusers_scheduler.set_timesteps(n_steps)
refiners_scheduler = DPMSolver(num_inference_steps=n_steps, last_step_first_order=last_step_first_order)
sample = randn(1, 3, 32, 32) sample = randn(1, 3, 32, 32)
noise = randn(1, 3, 32, 32) noise = randn(1, 3, 32, 32)