mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
fix scheduler device choice
This commit is contained in:
parent
282578ddc0
commit
cd1fdb5585
|
@ -39,6 +39,7 @@ class Scheduler(ABC):
|
||||||
start=initial_diffusion_rate**0.5,
|
start=initial_diffusion_rate**0.5,
|
||||||
end=final_diffusion_rate**0.5,
|
end=final_diffusion_rate**0.5,
|
||||||
steps=num_train_timesteps,
|
steps=num_train_timesteps,
|
||||||
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
** 2
|
** 2
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
|
import pytest
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
from warnings import warn
|
||||||
from refiners.foundationals.latent_diffusion.schedulers import Scheduler, DPMSolver, DDIM
|
from refiners.foundationals.latent_diffusion.schedulers import Scheduler, DPMSolver, DDIM
|
||||||
from refiners.fluxion import norm, manual_seed
|
from refiners.fluxion import norm, manual_seed
|
||||||
from torch import linspace, float32, randn, Tensor, allclose
|
from torch import linspace, float32, randn, Tensor, allclose, device as Device
|
||||||
|
|
||||||
|
|
||||||
def test_scheduler_utils():
|
def test_scheduler_utils():
|
||||||
|
@ -67,3 +69,15 @@ def test_ddim_solver_diffusers():
|
||||||
refiners_output = refiners_scheduler(x=sample, noise=noise, step=step)
|
refiners_output = refiners_scheduler(x=sample, noise=noise, step=step)
|
||||||
|
|
||||||
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_scheduler_device(test_device: Device):
|
||||||
|
if test_device.type == "cpu":
|
||||||
|
warn("not running on CPU, skipping")
|
||||||
|
pytest.skip()
|
||||||
|
|
||||||
|
scheduler = DDIM(num_inference_steps=30, device=test_device)
|
||||||
|
x = randn(1, 4, 32, 32, device=test_device)
|
||||||
|
noise = randn(1, 4, 32, 32, device=test_device)
|
||||||
|
noised = scheduler.add_noise(x, noise, scheduler.steps[0])
|
||||||
|
assert noised.device == test_device
|
||||||
|
|
Loading…
Reference in a new issue