make scheduler an actual abstract base class

This commit is contained in:
Cédric Deltheil 2023-09-12 16:26:17 +02:00 committed by Cédric Deltheil
parent 12e37f5d85
commit 1b4dcebe06
2 changed files with 10 additions and 3 deletions

View file

@ -1,11 +1,11 @@
from abc import abstractmethod from abc import ABC, abstractmethod
from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log
from typing import TypeVar from typing import TypeVar
T = TypeVar("T", bound="Scheduler") T = TypeVar("T", bound="Scheduler")
class Scheduler: class Scheduler(ABC):
""" """
A base class for creating a diffusion model scheduler. A base class for creating a diffusion model scheduler.

View file

@ -5,7 +5,14 @@ from torch import linspace, float32, randn, Tensor, allclose
def test_scheduler_utils(): def test_scheduler_utils():
scheduler = Scheduler(10, 20, 0.1, 0.2, "cpu") class DummyScheduler(Scheduler):
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
return Tensor()
def _generate_timesteps(self) -> Tensor:
return Tensor()
scheduler = DummyScheduler(10, 20, 0.1, 0.2, "cpu")
scale_factors = ( scale_factors = (
1.0 1.0
- linspace( - linspace(