mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
add a test for noise schedules
This commit is contained in:
parent
df843f5226
commit
3ddd258d36
|
@ -5,7 +5,7 @@ import pytest
|
|||
from torch import Tensor, allclose, device as Device, equal, isclose, randn
|
||||
|
||||
from refiners.fluxion import manual_seed
|
||||
from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler
|
||||
from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler, NoiseSchedule
|
||||
|
||||
|
||||
def test_ddpm_diffusers():
|
||||
|
@ -135,3 +135,11 @@ def test_scheduler_device(test_device: Device):
|
|||
noise = randn(1, 4, 32, 32, device=test_device)
|
||||
noised = scheduler.add_noise(x, noise, scheduler.first_inference_step)
|
||||
assert noised.device == test_device
|
||||
|
||||
|
||||
@pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS])
|
||||
def test_scheduler_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device):
|
||||
scheduler = DDIM(num_inference_steps=30, device=test_device, noise_schedule=noise_schedule)
|
||||
assert len(scheduler.scale_factors) == 1000
|
||||
assert scheduler.scale_factors[0] == 1 - scheduler.initial_diffusion_rate
|
||||
assert scheduler.scale_factors[-1] == 1 - scheduler.final_diffusion_rate
|
||||
|
|
Loading…
Reference in a new issue