diff --git a/tests/foundationals/latent_diffusion/test_schedulers.py b/tests/foundationals/latent_diffusion/test_schedulers.py index 01552f0..b4f371a 100644 --- a/tests/foundationals/latent_diffusion/test_schedulers.py +++ b/tests/foundationals/latent_diffusion/test_schedulers.py @@ -1,37 +1,9 @@ import pytest from typing import cast from warnings import warn -from refiners.foundationals.latent_diffusion.schedulers import Scheduler, DPMSolver, DDIM -from refiners.fluxion import norm, manual_seed -from torch import linspace, float32, randn, Tensor, allclose, device as Device - - -def test_scheduler_utils(): - class DummyScheduler(Scheduler): - def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor: - return Tensor() - - def _generate_timesteps(self) -> Tensor: - return Tensor() - - scheduler = DummyScheduler( - num_inference_steps=10, - num_train_timesteps=20, - initial_diffusion_rate=0.1, - final_diffusion_rate=0.2, - device="cpu", - ) - scale_factors = ( - 1.0 - - linspace( - start=0.1**0.5, - end=0.2**0.5, - steps=20, - dtype=float32, - ) - ** 2 - ) - assert norm(scheduler.scale_factors - scale_factors) == 0 +from refiners.foundationals.latent_diffusion.schedulers import DPMSolver, DDIM +from refiners.fluxion import manual_seed +from torch import randn, Tensor, allclose, device as Device def test_dpm_solver_diffusers():