make scheduler an actual abstract base class

This commit is contained in:
Cédric Deltheil 2023-09-12 16:26:17 +02:00 committed by Cédric Deltheil
parent 12e37f5d85
commit 1b4dcebe06
2 changed files with 10 additions and 3 deletions

View file

@ -1,11 +1,11 @@
from abc import abstractmethod
from abc import ABC, abstractmethod
from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log
from typing import TypeVar
T = TypeVar("T", bound="Scheduler")
class Scheduler:
class Scheduler(ABC):
"""
A base class for creating a diffusion model scheduler.

View file

@ -5,7 +5,14 @@ from torch import linspace, float32, randn, Tensor, allclose
def test_scheduler_utils():
scheduler = Scheduler(10, 20, 0.1, 0.2, "cpu")
class DummyScheduler(Scheduler):
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
return Tensor()
def _generate_timesteps(self) -> Tensor:
return Tensor()
scheduler = DummyScheduler(10, 20, 0.1, 0.2, "cpu")
scale_factors = (
1.0
- linspace(