From 1b4dcebe066d8aa2defec95483a8cdef0acd8fdc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Deltheil?= Date: Tue, 12 Sep 2023 16:26:17 +0200 Subject: [PATCH] make scheduler an actual abstract base class --- .../latent_diffusion/schedulers/scheduler.py | 4 ++-- tests/foundationals/latent_diffusion/test_schedulers.py | 9 ++++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py b/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py index 26cd6dd..a50ade0 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py @@ -1,11 +1,11 @@ -from abc import abstractmethod +from abc import ABC, abstractmethod from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log from typing import TypeVar T = TypeVar("T", bound="Scheduler") -class Scheduler: +class Scheduler(ABC): """ A base class for creating a diffusion model scheduler. diff --git a/tests/foundationals/latent_diffusion/test_schedulers.py b/tests/foundationals/latent_diffusion/test_schedulers.py index 0864602..46f7bc2 100644 --- a/tests/foundationals/latent_diffusion/test_schedulers.py +++ b/tests/foundationals/latent_diffusion/test_schedulers.py @@ -5,7 +5,14 @@ from torch import linspace, float32, randn, Tensor, allclose def test_scheduler_utils(): - scheduler = Scheduler(10, 20, 0.1, 0.2, "cpu") + class DummyScheduler(Scheduler): + def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor: + return Tensor() + + def _generate_timesteps(self) -> Tensor: + return Tensor() + + scheduler = DummyScheduler(10, 20, 0.1, 0.2, "cpu") scale_factors = ( 1.0 - linspace(