mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 15:02:01 +00:00
fix scheduler device choice
This commit is contained in:
parent
282578ddc0
commit
cd1fdb5585
|
@ -39,6 +39,7 @@ class Scheduler(ABC):
|
|||
start=initial_diffusion_rate**0.5,
|
||||
end=final_diffusion_rate**0.5,
|
||||
steps=num_train_timesteps,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
** 2
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
import pytest
|
||||
from typing import cast
|
||||
from warnings import warn
|
||||
from refiners.foundationals.latent_diffusion.schedulers import Scheduler, DPMSolver, DDIM
|
||||
from refiners.fluxion import norm, manual_seed
|
||||
from torch import linspace, float32, randn, Tensor, allclose
|
||||
from torch import linspace, float32, randn, Tensor, allclose, device as Device
|
||||
|
||||
|
||||
def test_scheduler_utils():
|
||||
|
@ -67,3 +69,15 @@ def test_ddim_solver_diffusers():
|
|||
refiners_output = refiners_scheduler(x=sample, noise=noise, step=step)
|
||||
|
||||
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
||||
|
||||
|
||||
def test_scheduler_device(test_device: Device):
|
||||
if test_device.type == "cpu":
|
||||
warn("not running on CPU, skipping")
|
||||
pytest.skip()
|
||||
|
||||
scheduler = DDIM(num_inference_steps=30, device=test_device)
|
||||
x = randn(1, 4, 32, 32, device=test_device)
|
||||
noise = randn(1, 4, 32, 32, device=test_device)
|
||||
noised = scheduler.add_noise(x, noise, scheduler.steps[0])
|
||||
assert noised.device == test_device
|
||||
|
|
Loading…
Reference in a new issue