add karras sampling to the Scheduler abstract class, default is quadratic

This commit is contained in:
limiteinductive 2023-12-03 17:11:36 +01:00 committed by Benjamin Trom
parent f22f969d65
commit ad8f02e555
3 changed files with 39 additions and 13 deletions

View file

@ -1,5 +1,5 @@
from torch import Tensor, device as Device, dtype as Dtype, arange, sqrt, float32, tensor
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
class DDIM(Scheduler):
@ -9,6 +9,7 @@ class DDIM(Scheduler):
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
device: Device | str = "cpu",
dtype: Dtype = float32,
) -> None:
@ -17,6 +18,7 @@ class DDIM(Scheduler):
num_train_timesteps,
initial_diffusion_rate,
final_diffusion_rate,
noise_schedule=noise_schedule,
device=device,
dtype=dtype,
)

View file

@ -1,4 +1,4 @@
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
import numpy as np
from torch import Tensor, device as Device, tensor, exp, float32, dtype as Dtype
from collections import deque
@ -16,6 +16,7 @@ class DPMSolver(Scheduler):
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
device: Device | str = "cpu",
dtype: Dtype = float32,
):
@ -24,6 +25,7 @@ class DPMSolver(Scheduler):
num_train_timesteps=num_train_timesteps,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
device=device,
dtype=dtype,
)

View file

@ -1,10 +1,17 @@
from abc import ABC, abstractmethod
from enum import Enum
from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log
from typing import TypeVar
T = TypeVar("T", bound="Scheduler")
class NoiseSchedule(str, Enum):
UNIFORM = "uniform"
QUADRATIC = "quadratic"
KARRAS = "karras"
class Scheduler(ABC):
"""
A base class for creating a diffusion model scheduler.
@ -24,6 +31,7 @@ class Scheduler(ABC):
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
device: Device | str = "cpu",
dtype: DType = float32,
):
@ -33,17 +41,8 @@ class Scheduler(ABC):
self.num_train_timesteps = num_train_timesteps
self.initial_diffusion_rate = initial_diffusion_rate
self.final_diffusion_rate = final_diffusion_rate
self.scale_factors = (
1.0
- linspace(
start=initial_diffusion_rate**0.5,
end=final_diffusion_rate**0.5,
steps=num_train_timesteps,
device=device,
dtype=dtype,
)
** 2
)
self.noise_schedule = noise_schedule
self.scale_factors = self.sample_noise_schedule()
self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
@ -71,6 +70,29 @@ class Scheduler(ABC):
def steps(self) -> list[int]:
return list(range(self.num_inference_steps))
def sample_power_distribution(self, power: float = 2, /) -> Tensor:
return (
linspace(
start=self.initial_diffusion_rate ** (1 / power),
end=self.final_diffusion_rate ** (1 / power),
steps=self.num_train_timesteps,
device=self.device,
dtype=self.dtype,
)
** power
)
def sample_noise_schedule(self) -> Tensor:
match self.noise_schedule:
case "uniform":
return 1 - self.sample_power_distribution(1)
case "quadratic":
return 1 - self.sample_power_distribution(2)
case "karras":
return 1 - self.sample_power_distribution(7)
case _:
raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")
def add_noise(
self,
x: Tensor,