mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
make scheduler an actual abstract base class
This commit is contained in:
parent
12e37f5d85
commit
1b4dcebe06
|
@ -1,11 +1,11 @@
|
|||
from abc import abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar("T", bound="Scheduler")
|
||||
|
||||
|
||||
class Scheduler:
|
||||
class Scheduler(ABC):
|
||||
"""
|
||||
A base class for creating a diffusion model scheduler.
|
||||
|
||||
|
|
|
@ -5,7 +5,14 @@ from torch import linspace, float32, randn, Tensor, allclose
|
|||
|
||||
|
||||
def test_scheduler_utils():
|
||||
scheduler = Scheduler(10, 20, 0.1, 0.2, "cpu")
|
||||
class DummyScheduler(Scheduler):
|
||||
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||
return Tensor()
|
||||
|
||||
def _generate_timesteps(self) -> Tensor:
|
||||
return Tensor()
|
||||
|
||||
scheduler = DummyScheduler(10, 20, 0.1, 0.2, "cpu")
|
||||
scale_factors = (
|
||||
1.0
|
||||
- linspace(
|
||||
|
|
Loading…
Reference in a new issue