add karras sampling to the Scheduler abstract class, default is quadratic

Author: limiteinductive, 2023-12-03 17:11:36 +01:00 (committed by Benjamin Trom)
parent f22f969d65
commit ad8f02e555
3 changed files with 39 additions and 13 deletions
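
Usage note (illustrative, not part of the commit): the diffs below add a noise_schedule parameter to the Scheduler base class and forward it from DDIM and DPMSolver; the default stays quadratic, so Karras sampling is opt-in. A minimal sketch, assuming the ddim module path and the num_inference_steps argument (neither appears in this diff):

from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM  # assumed module path
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule

# Opt in to the Karras schedule; omitting noise_schedule keeps the quadratic default.
scheduler = DDIM(num_inference_steps=30, noise_schedule=NoiseSchedule.KARRAS)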

View file (DDIM scheduler)

@@ -1,5 +1,5 @@
 from torch import Tensor, device as Device, dtype as Dtype, arange, sqrt, float32, tensor
-from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
+from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
 
 
 class DDIM(Scheduler):
@@ -9,6 +9,7 @@ class DDIM(Scheduler):
         num_train_timesteps: int = 1_000,
         initial_diffusion_rate: float = 8.5e-4,
         final_diffusion_rate: float = 1.2e-2,
+        noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
         device: Device | str = "cpu",
         dtype: Dtype = float32,
     ) -> None:
@@ -17,6 +18,7 @@ class DDIM(Scheduler):
             num_train_timesteps,
             initial_diffusion_rate,
             final_diffusion_rate,
+            noise_schedule=noise_schedule,
             device=device,
             dtype=dtype,
         )

View file (DPMSolver scheduler)

@@ -1,4 +1,4 @@
-from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
+from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
 import numpy as np
 from torch import Tensor, device as Device, tensor, exp, float32, dtype as Dtype
 from collections import deque
@@ -16,6 +16,7 @@ class DPMSolver(Scheduler):
         num_train_timesteps: int = 1_000,
         initial_diffusion_rate: float = 8.5e-4,
         final_diffusion_rate: float = 1.2e-2,
+        noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
         device: Device | str = "cpu",
         dtype: Dtype = float32,
     ):
@@ -24,6 +25,7 @@ class DPMSolver(Scheduler):
             num_train_timesteps=num_train_timesteps,
             initial_diffusion_rate=initial_diffusion_rate,
             final_diffusion_rate=final_diffusion_rate,
+            noise_schedule=noise_schedule,
             device=device,
             dtype=dtype,
         )

View file (Scheduler base class)

@@ -1,10 +1,17 @@
 from abc import ABC, abstractmethod
+from enum import Enum
 from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log
 from typing import TypeVar
 
 T = TypeVar("T", bound="Scheduler")
 
 
+class NoiseSchedule(str, Enum):
+    UNIFORM = "uniform"
+    QUADRATIC = "quadratic"
+    KARRAS = "karras"
+
+
 class Scheduler(ABC):
     """
     A base class for creating a diffusion model scheduler.
@@ -24,6 +31,7 @@ class Scheduler(ABC):
         num_train_timesteps: int = 1_000,
         initial_diffusion_rate: float = 8.5e-4,
         final_diffusion_rate: float = 1.2e-2,
+        noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
         device: Device | str = "cpu",
         dtype: DType = float32,
     ):
@@ -33,17 +41,8 @@ class Scheduler(ABC):
         self.num_train_timesteps = num_train_timesteps
         self.initial_diffusion_rate = initial_diffusion_rate
         self.final_diffusion_rate = final_diffusion_rate
-        self.scale_factors = (
-            1.0
-            - linspace(
-                start=initial_diffusion_rate**0.5,
-                end=final_diffusion_rate**0.5,
-                steps=num_train_timesteps,
-                device=device,
-                dtype=dtype,
-            )
-            ** 2
-        )
+        self.noise_schedule = noise_schedule
+        self.scale_factors = self.sample_noise_schedule()
         self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
         self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
         self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
@@ -71,6 +70,29 @@ class Scheduler(ABC):
     def steps(self) -> list[int]:
         return list(range(self.num_inference_steps))
 
+    def sample_power_distribution(self, power: float = 2, /) -> Tensor:
+        return (
+            linspace(
+                start=self.initial_diffusion_rate ** (1 / power),
+                end=self.final_diffusion_rate ** (1 / power),
+                steps=self.num_train_timesteps,
+                device=self.device,
+                dtype=self.dtype,
+            )
+            ** power
+        )
+
+    def sample_noise_schedule(self) -> Tensor:
+        match self.noise_schedule:
+            case "uniform":
+                return 1 - self.sample_power_distribution(1)
+            case "quadratic":
+                return 1 - self.sample_power_distribution(2)
+            case "karras":
+                return 1 - self.sample_power_distribution(7)
+            case _:
+                raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")
+
     def add_noise(
         self,
         x: Tensor,
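
For reference, a standalone sketch of the schedule construction introduced above (illustrative helper, not the library API): each schedule interpolates the diffusion rates linearly in the 1/power domain and raises the result back to that power, with power 1 for uniform, 2 for quadratic, and 7 for karras; the scale factors are one minus these betas.

import torch

def scale_factors(power: float, initial: float = 8.5e-4, final: float = 1.2e-2, steps: int = 1_000) -> torch.Tensor:
    # Linearly space the diffusion rates after a 1/power transform, then raise back to that power.
    betas = torch.linspace(initial ** (1 / power), final ** (1 / power), steps) ** power
    return 1.0 - betas  # the quantity the Scheduler stores as scale_factors

uniform, quadratic, karras = scale_factors(1), scale_factors(2), scale_factors(7)
cumulative = torch.sqrt(quadratic.cumprod(dim=0))  # mirrors cumulative_scale_factors in the base class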