mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 15:02:01 +00:00
add karras sampling to the Scheduler abstract class, default is quadratic
This commit is contained in:
parent
f22f969d65
commit
ad8f02e555
|
@ -1,5 +1,5 @@
|
|||
from torch import Tensor, device as Device, dtype as Dtype, arange, sqrt, float32, tensor
|
||||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
|
||||
|
||||
|
||||
class DDIM(Scheduler):
|
||||
|
@ -9,6 +9,7 @@ class DDIM(Scheduler):
|
|||
num_train_timesteps: int = 1_000,
|
||||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||
device: Device | str = "cpu",
|
||||
dtype: Dtype = float32,
|
||||
) -> None:
|
||||
|
@ -17,6 +18,7 @@ class DDIM(Scheduler):
|
|||
num_train_timesteps,
|
||||
initial_diffusion_rate,
|
||||
final_diffusion_rate,
|
||||
noise_schedule=noise_schedule,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
|
||||
import numpy as np
|
||||
from torch import Tensor, device as Device, tensor, exp, float32, dtype as Dtype
|
||||
from collections import deque
|
||||
|
@ -16,6 +16,7 @@ class DPMSolver(Scheduler):
|
|||
num_train_timesteps: int = 1_000,
|
||||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||
device: Device | str = "cpu",
|
||||
dtype: Dtype = float32,
|
||||
):
|
||||
|
@ -24,6 +25,7 @@ class DPMSolver(Scheduler):
|
|||
num_train_timesteps=num_train_timesteps,
|
||||
initial_diffusion_rate=initial_diffusion_rate,
|
||||
final_diffusion_rate=final_diffusion_rate,
|
||||
noise_schedule=noise_schedule,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
|
|
@ -1,10 +1,17 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar("T", bound="Scheduler")
|
||||
|
||||
|
||||
class NoiseSchedule(str, Enum):
|
||||
UNIFORM = "uniform"
|
||||
QUADRATIC = "quadratic"
|
||||
KARRAS = "karras"
|
||||
|
||||
|
||||
class Scheduler(ABC):
|
||||
"""
|
||||
A base class for creating a diffusion model scheduler.
|
||||
|
@ -24,6 +31,7 @@ class Scheduler(ABC):
|
|||
num_train_timesteps: int = 1_000,
|
||||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||
device: Device | str = "cpu",
|
||||
dtype: DType = float32,
|
||||
):
|
||||
|
@ -33,17 +41,8 @@ class Scheduler(ABC):
|
|||
self.num_train_timesteps = num_train_timesteps
|
||||
self.initial_diffusion_rate = initial_diffusion_rate
|
||||
self.final_diffusion_rate = final_diffusion_rate
|
||||
self.scale_factors = (
|
||||
1.0
|
||||
- linspace(
|
||||
start=initial_diffusion_rate**0.5,
|
||||
end=final_diffusion_rate**0.5,
|
||||
steps=num_train_timesteps,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
** 2
|
||||
)
|
||||
self.noise_schedule = noise_schedule
|
||||
self.scale_factors = self.sample_noise_schedule()
|
||||
self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
|
||||
self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
|
||||
self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
|
||||
|
@ -71,6 +70,29 @@ class Scheduler(ABC):
|
|||
def steps(self) -> list[int]:
|
||||
return list(range(self.num_inference_steps))
|
||||
|
||||
def sample_power_distribution(self, power: float = 2, /) -> Tensor:
|
||||
return (
|
||||
linspace(
|
||||
start=self.initial_diffusion_rate ** (1 / power),
|
||||
end=self.final_diffusion_rate ** (1 / power),
|
||||
steps=self.num_train_timesteps,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
** power
|
||||
)
|
||||
|
||||
def sample_noise_schedule(self) -> Tensor:
|
||||
match self.noise_schedule:
|
||||
case "uniform":
|
||||
return 1 - self.sample_power_distribution(1)
|
||||
case "quadratic":
|
||||
return 1 - self.sample_power_distribution(2)
|
||||
case "karras":
|
||||
return 1 - self.sample_power_distribution(7)
|
||||
case _:
|
||||
raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
x: Tensor,
|
||||
|
|
Loading…
Reference in a new issue