diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py b/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py index 26cd6dd..a50ade0 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py @@ -1,11 +1,11 @@ -from abc import abstractmethod +from abc import ABC, abstractmethod from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log from typing import TypeVar T = TypeVar("T", bound="Scheduler") -class Scheduler: +class Scheduler(ABC): """ A base class for creating a diffusion model scheduler. diff --git a/tests/foundationals/latent_diffusion/test_schedulers.py b/tests/foundationals/latent_diffusion/test_schedulers.py index 0864602..46f7bc2 100644 --- a/tests/foundationals/latent_diffusion/test_schedulers.py +++ b/tests/foundationals/latent_diffusion/test_schedulers.py @@ -5,7 +5,14 @@ from torch import linspace, float32, randn, Tensor, allclose def test_scheduler_utils(): - scheduler = Scheduler(10, 20, 0.1, 0.2, "cpu") + class DummyScheduler(Scheduler): + def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor: + return Tensor() + + def _generate_timesteps(self) -> Tensor: + return Tensor() + + scheduler = DummyScheduler(10, 20, 0.1, 0.2, "cpu") scale_factors = ( 1.0 - linspace(