From 3ddd258d36a1b8146796ea33974f4075cbf8d3a7 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Tue, 30 Jan 2024 17:47:06 +0100 Subject: [PATCH] add a test for noise schedules --- tests/foundationals/latent_diffusion/test_solvers.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/foundationals/latent_diffusion/test_solvers.py b/tests/foundationals/latent_diffusion/test_solvers.py index f7c8c4d..be5167d 100644 --- a/tests/foundationals/latent_diffusion/test_solvers.py +++ b/tests/foundationals/latent_diffusion/test_solvers.py @@ -5,7 +5,7 @@ import pytest from torch import Tensor, allclose, device as Device, equal, isclose, randn from refiners.fluxion import manual_seed -from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler +from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler, NoiseSchedule def test_ddpm_diffusers(): @@ -135,3 +135,11 @@ def test_scheduler_device(test_device: Device): noise = randn(1, 4, 32, 32, device=test_device) noised = scheduler.add_noise(x, noise, scheduler.first_inference_step) assert noised.device == test_device + + +@pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS]) +def test_scheduler_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device): + scheduler = DDIM(num_inference_steps=30, device=test_device, noise_schedule=noise_schedule) + assert len(scheduler.scale_factors) == 1000 + assert scheduler.scale_factors[0] == 1 - scheduler.initial_diffusion_rate + assert scheduler.scale_factors[-1] == 1 - scheduler.final_diffusion_rate