mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
make scheduler an actual abstract base class
This commit is contained in:
parent
12e37f5d85
commit
1b4dcebe06
|
@ -1,11 +1,11 @@
|
||||||
from abc import abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log
|
from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
T = TypeVar("T", bound="Scheduler")
|
T = TypeVar("T", bound="Scheduler")
|
||||||
|
|
||||||
|
|
||||||
class Scheduler:
|
class Scheduler(ABC):
|
||||||
"""
|
"""
|
||||||
A base class for creating a diffusion model scheduler.
|
A base class for creating a diffusion model scheduler.
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,14 @@ from torch import linspace, float32, randn, Tensor, allclose
|
||||||
|
|
||||||
|
|
||||||
def test_scheduler_utils():
|
def test_scheduler_utils():
|
||||||
scheduler = Scheduler(10, 20, 0.1, 0.2, "cpu")
|
class DummyScheduler(Scheduler):
|
||||||
|
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||||
|
return Tensor()
|
||||||
|
|
||||||
|
def _generate_timesteps(self) -> Tensor:
|
||||||
|
return Tensor()
|
||||||
|
|
||||||
|
scheduler = DummyScheduler(10, 20, 0.1, 0.2, "cpu")
|
||||||
scale_factors = (
|
scale_factors = (
|
||||||
1.0
|
1.0
|
||||||
- linspace(
|
- linspace(
|
||||||
|
|
Loading…
Reference in a new issue