fix scheduler device choice

This commit is contained in:
Pierre Chapuis 2023-09-21 11:47:11 +02:00
parent 282578ddc0
commit cd1fdb5585
2 changed files with 16 additions and 1 deletions

View file

@ -39,6 +39,7 @@ class Scheduler(ABC):
start=initial_diffusion_rate**0.5, start=initial_diffusion_rate**0.5,
end=final_diffusion_rate**0.5, end=final_diffusion_rate**0.5,
steps=num_train_timesteps, steps=num_train_timesteps,
device=device,
dtype=dtype, dtype=dtype,
) )
** 2 ** 2

View file

@ -1,7 +1,9 @@
import pytest
from typing import cast from typing import cast
from warnings import warn
from refiners.foundationals.latent_diffusion.schedulers import Scheduler, DPMSolver, DDIM from refiners.foundationals.latent_diffusion.schedulers import Scheduler, DPMSolver, DDIM
from refiners.fluxion import norm, manual_seed from refiners.fluxion import norm, manual_seed
from torch import linspace, float32, randn, Tensor, allclose from torch import linspace, float32, randn, Tensor, allclose, device as Device
def test_scheduler_utils(): def test_scheduler_utils():
@ -67,3 +69,15 @@ def test_ddim_solver_diffusers():
refiners_output = refiners_scheduler(x=sample, noise=noise, step=step) refiners_output = refiners_scheduler(x=sample, noise=noise, step=step)
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}" assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
def test_scheduler_device(test_device: Device):
if test_device.type == "cpu":
warn("not running on CPU, skipping")
pytest.skip()
scheduler = DDIM(num_inference_steps=30, device=test_device)
x = randn(1, 4, 32, 32, device=test_device)
noise = randn(1, 4, 32, 32, device=test_device)
noised = scheduler.add_noise(x, noise, scheduler.steps[0])
assert noised.device == test_device