add a test for noise schedules

This commit is contained in:
Pierre Chapuis 2024-01-30 17:47:06 +01:00 committed by Cédric Deltheil
parent df843f5226
commit 3ddd258d36

View file

@ -5,7 +5,7 @@ import pytest
from torch import Tensor, allclose, device as Device, equal, isclose, randn
from refiners.fluxion import manual_seed
from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler
from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler, NoiseSchedule
def test_ddpm_diffusers():
@ -135,3 +135,11 @@ def test_scheduler_device(test_device: Device):
noise = randn(1, 4, 32, 32, device=test_device)
noised = scheduler.add_noise(x, noise, scheduler.first_inference_step)
assert noised.device == test_device
@pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS])
def test_scheduler_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device):
scheduler = DDIM(num_inference_steps=30, device=test_device, noise_schedule=noise_schedule)
assert len(scheduler.scale_factors) == 1000
assert scheduler.scale_factors[0] == 1 - scheduler.initial_diffusion_rate
assert scheduler.scale_factors[-1] == 1 - scheduler.final_diffusion_rate