diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py b/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py index f873d53..9c1bf7f 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py @@ -1,5 +1,5 @@ from torch import Tensor, device as Device, dtype as Dtype, arange, sqrt, float32, tensor -from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler +from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler class DDIM(Scheduler): @@ -9,6 +9,7 @@ class DDIM(Scheduler): num_train_timesteps: int = 1_000, initial_diffusion_rate: float = 8.5e-4, final_diffusion_rate: float = 1.2e-2, + noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC, device: Device | str = "cpu", dtype: Dtype = float32, ) -> None: @@ -17,6 +18,7 @@ class DDIM(Scheduler): num_train_timesteps, initial_diffusion_rate, final_diffusion_rate, + noise_schedule=noise_schedule, device=device, dtype=dtype, ) diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py b/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py index 438a964..fe4d7e3 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py @@ -1,4 +1,4 @@ -from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler +from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler import numpy as np from torch import Tensor, device as Device, tensor, exp, float32, dtype as Dtype from collections import deque @@ -16,6 +16,7 @@ class DPMSolver(Scheduler): num_train_timesteps: int = 1_000, initial_diffusion_rate: float = 8.5e-4, final_diffusion_rate: float = 1.2e-2, + noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC, device: Device | str = "cpu", dtype: Dtype = float32, ): @@ -24,6 +25,7 @@ class DPMSolver(Scheduler): num_train_timesteps=num_train_timesteps, initial_diffusion_rate=initial_diffusion_rate, final_diffusion_rate=final_diffusion_rate, + noise_schedule=noise_schedule, device=device, dtype=dtype, ) diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py b/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py index 6e47672..e0d1fba 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py @@ -1,10 +1,17 @@ from abc import ABC, abstractmethod +from enum import Enum from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log from typing import TypeVar T = TypeVar("T", bound="Scheduler") +class NoiseSchedule(str, Enum): + UNIFORM = "uniform" + QUADRATIC = "quadratic" + KARRAS = "karras" + + class Scheduler(ABC): """ A base class for creating a diffusion model scheduler. @@ -24,6 +31,7 @@ class Scheduler(ABC): num_train_timesteps: int = 1_000, initial_diffusion_rate: float = 8.5e-4, final_diffusion_rate: float = 1.2e-2, + noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC, device: Device | str = "cpu", dtype: DType = float32, ): @@ -33,17 +41,8 @@ class Scheduler(ABC): self.num_train_timesteps = num_train_timesteps self.initial_diffusion_rate = initial_diffusion_rate self.final_diffusion_rate = final_diffusion_rate - self.scale_factors = ( - 1.0 - - linspace( - start=initial_diffusion_rate**0.5, - end=final_diffusion_rate**0.5, - steps=num_train_timesteps, - device=device, - dtype=dtype, - ) - ** 2 - ) + self.noise_schedule = noise_schedule + self.scale_factors = self.sample_noise_schedule() self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0)) self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0)) self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std) @@ -71,6 +70,29 @@ class Scheduler(ABC): def steps(self) -> list[int]: return list(range(self.num_inference_steps)) + def sample_power_distribution(self, power: float = 2, /) -> Tensor: + return ( + linspace( + start=self.initial_diffusion_rate ** (1 / power), + end=self.final_diffusion_rate ** (1 / power), + steps=self.num_train_timesteps, + device=self.device, + dtype=self.dtype, + ) + ** power + ) + + def sample_noise_schedule(self) -> Tensor: + match self.noise_schedule: + case "uniform": + return 1 - self.sample_power_distribution(1) + case "quadratic": + return 1 - self.sample_power_distribution(2) + case "karras": + return 1 - self.sample_power_distribution(7) + case _: + raise ValueError(f"Unknown noise schedule: {self.noise_schedule}") + def add_noise( self, x: Tensor,