From ddc1cf8ca7a1e2fd404c0db2a5ad2bd10a89336b Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Thu, 22 Feb 2024 12:02:58 +0100 Subject: [PATCH] refactor solvers to support different timestep spacings --- .../latent_diffusion/solvers/__init__.py | 4 +- .../latent_diffusion/solvers/ddim.py | 30 ++++--- .../latent_diffusion/solvers/ddpm.py | 15 ++-- .../latent_diffusion/solvers/dpm.py | 45 ++++++----- .../latent_diffusion/solvers/euler.py | 41 +++++----- .../latent_diffusion/solvers/lcm.py | 48 +++++++---- .../latent_diffusion/solvers/solver.py | 80 ++++++++++++++++--- .../latent_diffusion/test_solvers.py | 54 ++++++++++--- 8 files changed, 220 insertions(+), 97 deletions(-) diff --git a/src/refiners/foundationals/latent_diffusion/solvers/__init__.py b/src/refiners/foundationals/latent_diffusion/solvers/__init__.py index 6a9aa07..475d959 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/__init__.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/__init__.py @@ -3,6 +3,6 @@ from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver from refiners.foundationals.latent_diffusion.solvers.euler import Euler from refiners.foundationals.latent_diffusion.solvers.lcm import LCMSolver -from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver +from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing -__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler", "LCMSolver", "NoiseSchedule"] +__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler", "LCMSolver", "NoiseSchedule", "TimestepSpacing"] diff --git a/src/refiners/foundationals/latent_diffusion/solvers/ddim.py b/src/refiners/foundationals/latent_diffusion/solvers/ddim.py index 141d088..9ea146e 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/ddim.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/ddim.py @@ -1,6 +1,6 @@ -from torch import Generator, Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor +from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor -from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver +from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing class DDIM(Solver): @@ -13,6 +13,8 @@ class DDIM(Solver): self, num_inference_steps: int, num_train_timesteps: int = 1_000, + timesteps_spacing: TimestepSpacing = TimestepSpacing.LEADING, + timesteps_offset: int = 1, initial_diffusion_rate: float = 8.5e-4, final_diffusion_rate: float = 1.2e-2, noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC, @@ -25,6 +27,8 @@ class DDIM(Solver): Args: num_inference_steps: The number of inference steps. num_train_timesteps: The number of training timesteps. + timesteps_spacing: The spacing to use for the timesteps. + timesteps_offset: The offset to use for the timesteps. initial_diffusion_rate: The initial diffusion rate. final_diffusion_rate: The final diffusion rate. noise_schedule: The noise schedule. @@ -35,6 +39,8 @@ class DDIM(Solver): super().__init__( num_inference_steps=num_inference_steps, num_train_timesteps=num_train_timesteps, + timesteps_spacing=timesteps_spacing, + timesteps_offset=timesteps_offset, initial_diffusion_rate=initial_diffusion_rate, final_diffusion_rate=final_diffusion_rate, noise_schedule=noise_schedule, @@ -43,16 +49,18 @@ class DDIM(Solver): dtype=dtype, ) - def _generate_timesteps(self) -> Tensor: - """ - Generates decreasing timesteps with 'leading' spacing and offset of 1 - similar to diffusers settings for the DDIM solver in Stable Diffusion 1.5 - """ - step_ratio = self.num_train_timesteps // self.num_inference_steps - timesteps = arange(start=0, end=self.num_inference_steps, step=1) * step_ratio + 1 - return timesteps.flip(0) - def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: + """Apply one step of the backward diffusion process. + + Args: + x: The input tensor to apply the diffusion process to. + predicted_noise: The predicted noise tensor for the current step. + step: The current step of the diffusion process. + generator: The random number generator to use for sampling noise (ignored, this solver is deterministic). + + Returns: + The denoised version of the input data `x`. + """ assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}" timestep, previous_timestep = ( diff --git a/src/refiners/foundationals/latent_diffusion/solvers/ddpm.py b/src/refiners/foundationals/latent_diffusion/solvers/ddpm.py index 49ae9c2..cd8547e 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/ddpm.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/ddpm.py @@ -1,6 +1,6 @@ -from torch import Generator, Tensor, arange, device as Device +from torch import Generator, Tensor, device as Device -from refiners.foundationals.latent_diffusion.solvers.solver import Solver +from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing class DDPM(Solver): @@ -17,6 +17,8 @@ class DDPM(Solver): self, num_inference_steps: int, num_train_timesteps: int = 1_000, + timesteps_spacing: TimestepSpacing = TimestepSpacing.LEADING, + timesteps_offset: int = 0, initial_diffusion_rate: float = 8.5e-4, final_diffusion_rate: float = 1.2e-2, first_inference_step: int = 0, @@ -27,6 +29,8 @@ class DDPM(Solver): Args: num_inference_steps: The number of inference steps. num_train_timesteps: The number of training timesteps. + timesteps_spacing: The spacing to use for the timesteps. + timesteps_offset: The offset to use for the timesteps. initial_diffusion_rate: The initial diffusion rate. final_diffusion_rate: The final diffusion rate. first_inference_step: The first inference step. @@ -35,16 +39,13 @@ class DDPM(Solver): super().__init__( num_inference_steps=num_inference_steps, num_train_timesteps=num_train_timesteps, + timesteps_spacing=timesteps_spacing, + timesteps_offset=timesteps_offset, initial_diffusion_rate=initial_diffusion_rate, final_diffusion_rate=final_diffusion_rate, first_inference_step=first_inference_step, device=device, ) - def _generate_timesteps(self) -> Tensor: - step_ratio = self.num_train_timesteps // self.num_inference_steps - timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio - return timesteps.flip(0) - def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: raise NotImplementedError diff --git a/src/refiners/foundationals/latent_diffusion/solvers/dpm.py b/src/refiners/foundationals/latent_diffusion/solvers/dpm.py index 64351a7..5ef3bfd 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/dpm.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/dpm.py @@ -1,9 +1,8 @@ from collections import deque -import numpy as np from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor -from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver +from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing class DPMSolver(Solver): @@ -23,6 +22,8 @@ class DPMSolver(Solver): self, num_inference_steps: int, num_train_timesteps: int = 1_000, + timesteps_spacing: TimestepSpacing = TimestepSpacing.TRAILING_ALT, + timesteps_offset: int = 0, initial_diffusion_rate: float = 8.5e-4, final_diffusion_rate: float = 1.2e-2, last_step_first_order: bool = False, @@ -31,9 +32,26 @@ class DPMSolver(Solver): device: Device | str = "cpu", dtype: Dtype = float32, ): + """Initializes a new DPM solver. + + Args: + num_inference_steps: The number of inference steps. + num_train_timesteps: The number of training timesteps. + timesteps_spacing: The spacing to use for the timesteps. + timesteps_offset: The offset to use for the timesteps. + initial_diffusion_rate: The initial diffusion rate. + final_diffusion_rate: The final diffusion rate. + last_step_first_order: Use a first-order update for the last step. + noise_schedule: The noise schedule. + first_inference_step: The first inference step. + device: The PyTorch device to use. + dtype: The PyTorch data type to use. + """ super().__init__( num_inference_steps=num_inference_steps, num_train_timesteps=num_train_timesteps, + timesteps_spacing=timesteps_spacing, + timesteps_offset=timesteps_offset, initial_diffusion_rate=initial_diffusion_rate, final_diffusion_rate=final_diffusion_rate, noise_schedule=noise_schedule, @@ -44,21 +62,6 @@ class DPMSolver(Solver): self.estimated_data = deque([tensor([])] * 2, maxlen=2) self.last_step_first_order = last_step_first_order - def _generate_timesteps(self) -> Tensor: - """Generate the timesteps used by the solver. - - Note: - We need to use numpy here because: - - - numpy.linspace(0,999,31)[15] is 499.49999999999994 - - torch.linspace(0,999,31)[15] is 499.5 - - and we want the same result as the original codebase. - """ - return tensor( - np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:], - ).flip(0) - def rebuild( self: "DPMSolver", num_inference_steps: int | None, @@ -148,10 +151,10 @@ class DPMSolver(Solver): (ODEs). Args: - x: The input data. - predicted_noise: The predicted noise. - step: The current step. - generator: The random number generator. + x: The input tensor to apply the diffusion process to. + predicted_noise: The predicted noise tensor for the current step. + step: The current step of the diffusion process. + generator: The random number generator to use for sampling noise (ignored, this solver is deterministic). Returns: The denoised version of the input data `x`. diff --git a/src/refiners/foundationals/latent_diffusion/solvers/euler.py b/src/refiners/foundationals/latent_diffusion/solvers/euler.py index cef8897..fc2ef73 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/euler.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/euler.py @@ -2,7 +2,7 @@ import numpy as np import torch from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor -from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver +from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing class Euler(Solver): @@ -16,6 +16,8 @@ class Euler(Solver): self, num_inference_steps: int, num_train_timesteps: int = 1_000, + timesteps_spacing: TimestepSpacing = TimestepSpacing.LINSPACE_FLOAT, + timesteps_offset: int = 0, initial_diffusion_rate: float = 8.5e-4, final_diffusion_rate: float = 1.2e-2, noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC, @@ -28,6 +30,8 @@ class Euler(Solver): Args: num_inference_steps: The number of inference steps. num_train_timesteps: The number of training timesteps. + timesteps_spacing: The spacing to use for the timesteps. + timesteps_offset: The offset to use for the timesteps. initial_diffusion_rate: The initial diffusion rate. final_diffusion_rate: The final diffusion rate. noise_schedule: The noise schedule. @@ -40,6 +44,8 @@ class Euler(Solver): super().__init__( num_inference_steps=num_inference_steps, num_train_timesteps=num_train_timesteps, + timesteps_spacing=timesteps_spacing, + timesteps_offset=timesteps_offset, initial_diffusion_rate=initial_diffusion_rate, final_diffusion_rate=final_diffusion_rate, noise_schedule=noise_schedule, @@ -54,20 +60,6 @@ class Euler(Solver): """The initial noise sigma.""" return self.sigmas.max() - def _generate_timesteps(self) -> Tensor: - """Generate the timesteps used by the solver. - - Note: - We need to use numpy here because: - - - numpy.linspace(0,999,31)[15] is 499.49999999999994 - - torch.linspace(0,999,31)[15] is 499.5 - - and we want the same result as the original codebase. - """ - timesteps = torch.tensor(np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps)).flip(0) - return timesteps - def _generate_sigmas(self) -> Tensor: """Generate the sigmas used by the solver.""" sigmas = self.noise_std / self.cumulative_scale_factors @@ -88,12 +80,17 @@ class Euler(Solver): sigma = self.sigmas[step] return x / ((sigma**2 + 1) ** 0.5) - def __call__( - self, - x: Tensor, - predicted_noise: Tensor, - step: int, - generator: Generator | None = None, - ) -> Tensor: + def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: + """Apply one step of the backward diffusion process. + + Args: + x: The input tensor to apply the diffusion process to. + predicted_noise: The predicted noise tensor for the current step. + step: The current step of the diffusion process. + generator: The random number generator to use for sampling noise (ignored, this solver is deterministic). + + Returns: + The denoised version of the input data `x`. + """ assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}" return x + predicted_noise * (self.sigmas[step + 1] - self.sigmas[step]) diff --git a/src/refiners/foundationals/latent_diffusion/solvers/lcm.py b/src/refiners/foundationals/latent_diffusion/solvers/lcm.py index 00d6b50..0ffd581 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/lcm.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/lcm.py @@ -1,8 +1,7 @@ -import numpy as np import torch from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver -from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver +from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing class LCMSolver(Solver): @@ -20,14 +19,29 @@ class LCMSolver(Solver): self, num_inference_steps: int, num_train_timesteps: int = 1_000, + timesteps_spacing: TimestepSpacing = TimestepSpacing.TRAILING, + timesteps_offset: int = 0, num_orig_steps: int = 50, initial_diffusion_rate: float = 8.5e-4, final_diffusion_rate: float = 1.2e-2, noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC, - diffusers_mode: bool = False, device: torch.device | str = "cpu", dtype: torch.dtype = torch.float32, ): + """Initializes a new LCM solver. + + Args: + num_inference_steps: The number of inference steps. + num_train_timesteps: The number of training timesteps. + timesteps_spacing: The spacing to use for the timesteps. + timesteps_offset: The offset to use for the timesteps. + num_orig_steps: The number of inference steps of the emulated DPM solver. + initial_diffusion_rate: The initial diffusion rate. + final_diffusion_rate: The final diffusion rate. + noise_schedule: The noise schedule. + device: The PyTorch device to use. + dtype: The PyTorch data type to use. + """ assert ( num_orig_steps >= num_inference_steps ), f"num_orig_steps ({num_orig_steps}) < num_inference_steps ({num_inference_steps})" @@ -36,22 +50,17 @@ class LCMSolver(Solver): DPMSolver( num_inference_steps=num_orig_steps, num_train_timesteps=num_train_timesteps, + timesteps_spacing=timesteps_spacing, device=device, dtype=dtype, ) ] - if diffusers_mode: - # Diffusers recomputes the timesteps in LCMScheduler, - # and it does it slightly differently than DPM Solver. - # We provide this option to reproduce Diffusers' output. - k = num_train_timesteps // num_orig_steps - ts = np.asarray(list(range(1, num_orig_steps + 1))) * k - 1 - self.dpm.timesteps = torch.tensor(ts, device=device).flip(0) - super().__init__( num_inference_steps=num_inference_steps, num_train_timesteps=num_train_timesteps, + timesteps_spacing=timesteps_spacing, + timesteps_offset=timesteps_offset, initial_diffusion_rate=initial_diffusion_rate, final_diffusion_rate=final_diffusion_rate, noise_schedule=noise_schedule, @@ -85,12 +94,19 @@ class LCMSolver(Solver): return self.dpm.timesteps[self.timestep_indices] def __call__( - self, - x: torch.Tensor, - predicted_noise: torch.Tensor, - step: int, - generator: torch.Generator | None = None, + self, x: torch.Tensor, predicted_noise: torch.Tensor, step: int, generator: torch.Generator | None = None ) -> torch.Tensor: + """Apply one step of the backward diffusion process. + + Args: + x: The input tensor to apply the diffusion process to. + predicted_noise: The predicted noise tensor for the current step. + step: The current step of the diffusion process. + generator: The random number generator to use for sampling noise. + + Returns: + The denoised version of the input data `x`. + """ current_timestep = self.timesteps[step] scale_factor = self.cumulative_scale_factors[current_timestep] noise_ratio = self.noise_std[current_timestep] diff --git a/src/refiners/foundationals/latent_diffusion/solvers/solver.py b/src/refiners/foundationals/latent_diffusion/solvers/solver.py index d11dc68..32ef32d 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/solver.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/solver.py @@ -2,7 +2,8 @@ from abc import ABC, abstractmethod from enum import Enum from typing import TypeVar -from torch import Generator, Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt +import numpy as np +from torch import Generator, Tensor, arange, device as Device, dtype as DType, float32, linspace, log, sqrt, tensor from refiners.fluxion import layers as fl @@ -10,11 +11,11 @@ T = TypeVar("T", bound="Solver") class NoiseSchedule(str, Enum): - """An enumeration of noise schedules used to sample the noise schedule. + """An enumeration of schedules used to sample the noise. Attributes: UNIFORM: A uniform noise schedule. - QUADRATIC: A quadratic noise schedule. + QUADRATIC: A quadratic noise schedule. Corresponds to "Stable Diffusion" in [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 1. KARRAS: See [[arXiv:2206.00364] Elucidating the Design Space of Diffusion-Based Generative Models, Equation 5](https://arxiv.org/abs/2206.00364) """ @@ -23,6 +24,26 @@ class NoiseSchedule(str, Enum): KARRAS = "karras" +class TimestepSpacing(str, Enum): + """An enumeration of methods to space the timesteps. + + See [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 2. + + Attributes: + LINSPACE_FLOAT: Sample N steps with linear interpolation, return a floating-point tensor. + LINSPACE_INT: Same as LINSPACE_FLOAT but return an integer tensor with rounded timesteps. + LEADING: Sample N+1 steps, do not include the last timestep (i.e. bad - non-zero SNR). Used in DDIM. + TRAILING: Sample N+1 steps, do not include the first timestep. + TRAILING_ALT: Variant of TRAILING used in DPM. + """ + + LINSPACE_FLOAT = "linspace_float" + LINSPACE_INT = "linspace_int" + LEADING = "leading" + TRAILING = "trailing" + TRAILING_ALT = "trailing_alt" + + class Solver(fl.Module, ABC): """The base class for creating a diffusion model solver. @@ -39,6 +60,8 @@ class Solver(fl.Module, ABC): self, num_inference_steps: int, num_train_timesteps: int = 1_000, + timesteps_spacing: TimestepSpacing = TimestepSpacing.LINSPACE_FLOAT, + timesteps_offset: int = 0, initial_diffusion_rate: float = 8.5e-4, final_diffusion_rate: float = 1.2e-2, noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC, @@ -51,6 +74,8 @@ class Solver(fl.Module, ABC): Args: num_inference_steps: The number of inference steps to perform. num_train_timesteps: The number of timesteps used to train the diffusion process. + timesteps_spacing: The spacing to use for the timesteps. + timesteps_offset: The offset to use for the timesteps. initial_diffusion_rate: The initial diffusion rate used to sample the noise schedule. final_diffusion_rate: The final diffusion rate used to sample the noise schedule. noise_schedule: The noise schedule used to sample the noise schedule. @@ -61,6 +86,8 @@ class Solver(fl.Module, ABC): super().__init__() self.num_inference_steps = num_inference_steps self.num_train_timesteps = num_train_timesteps + self.timesteps_spacing = timesteps_spacing + self.timesteps_offset = timesteps_offset self.initial_diffusion_rate = initial_diffusion_rate self.final_diffusion_rate = final_diffusion_rate self.noise_schedule = noise_schedule @@ -87,14 +114,49 @@ class Solver(fl.Module, ABC): """ ... - @abstractmethod - def _generate_timesteps(self) -> Tensor: - """Generate a tensor of timesteps. + @staticmethod + def generate_timesteps( + spacing: TimestepSpacing, + num_inference_steps: int, + num_train_timesteps: int = 1000, + offset: int = 0, + ) -> Tensor: + """Generate a tensor of timesteps according to a given spacing. - Note: - This method should be overridden by subclasses to provide the specific timesteps for the diffusion process. + Args: + spacing: The spacing to use for the timesteps. + num_inference_steps: The number of inference steps to perform. + num_train_timesteps: The number of timesteps used to train the diffusion process. + offset: The offset to use for the timesteps. """ - ... + max_timestep = num_train_timesteps - 1 + offset + match spacing: + case TimestepSpacing.LINSPACE_FLOAT: + return tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=float32).flip(0) + case TimestepSpacing.LINSPACE_INT: + return tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0) + case TimestepSpacing.LEADING: + step_ratio = num_train_timesteps // num_inference_steps + return (arange(0, num_inference_steps, 1) * step_ratio + offset).flip(0) + case TimestepSpacing.TRAILING: + step_ratio = num_train_timesteps // num_inference_steps + max_timestep = num_train_timesteps - 1 + offset + return arange(max_timestep, offset, -step_ratio) + case TimestepSpacing.TRAILING_ALT: + # We use numpy here because: + # numpy.linspace(0,999,31)[15] is 499.49999999999994 + # torch.linspace(0,999,31)[15] is 499.5 + # and we want the same result as the original DPM codebase. + np_space = np.linspace(offset, max_timestep, num_inference_steps + 1).round().astype(int)[1:] + return tensor(np_space).flip(0) + + def _generate_timesteps(self) -> Tensor: + return self.generate_timesteps( + spacing=self.timesteps_spacing, + num_inference_steps=self.num_inference_steps, + num_train_timesteps=self.num_train_timesteps, + offset=self.timesteps_offset, + ) def add_noise( self, diff --git a/tests/foundationals/latent_diffusion/test_solvers.py b/tests/foundationals/latent_diffusion/test_solvers.py index 97444c3..109494a 100644 --- a/tests/foundationals/latent_diffusion/test_solvers.py +++ b/tests/foundationals/latent_diffusion/test_solvers.py @@ -2,10 +2,19 @@ from typing import cast from warnings import warn import pytest -from torch import Generator, Tensor, allclose, device as Device, equal, isclose, randn +from torch import Generator, Tensor, allclose, device as Device, equal, isclose, randn, tensor from refiners.fluxion import manual_seed -from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler, LCMSolver, NoiseSchedule +from refiners.foundationals.latent_diffusion.solvers import ( + DDIM, + DDPM, + DPMSolver, + Euler, + LCMSolver, + NoiseSchedule, + Solver, + TimestepSpacing, +) def test_ddpm_diffusers(): @@ -14,7 +23,6 @@ def test_ddpm_diffusers(): diffusers_scheduler = DDPMScheduler(beta_schedule="scaled_linear", beta_start=0.00085, beta_end=0.012) diffusers_scheduler.set_timesteps(1000) refiners_scheduler = DDPM(num_inference_steps=1000) - assert equal(diffusers_scheduler.timesteps, refiners_scheduler.timesteps) @@ -34,6 +42,7 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool): ) diffusers_scheduler.set_timesteps(n_steps) refiners_scheduler = DPMSolver(num_inference_steps=n_steps, last_step_first_order=last_step_first_order) + assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps) sample = randn(1, 3, 32, 32) predicted_noise = randn(1, 3, 32, 32) @@ -59,6 +68,7 @@ def test_ddim_diffusers(): ) diffusers_scheduler.set_timesteps(30) refiners_scheduler = DDIM(num_inference_steps=30) + assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps) sample = randn(1, 4, 32, 32) predicted_noise = randn(1, 4, 32, 32) @@ -85,6 +95,7 @@ def test_euler_diffusers(): ) diffusers_scheduler.set_timesteps(30) refiners_scheduler = Euler(num_inference_steps=30) + assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps) sample = randn(1, 4, 32, 32) predicted_noise = randn(1, 4, 32, 32) @@ -111,9 +122,7 @@ def test_lcm_diffusers(): diffusers_scheduler = LCMScheduler() diffusers_scheduler.set_timesteps(4) - refiners_scheduler = LCMSolver(num_inference_steps=4, diffusers_mode=True) - - # diffusers_mode means the timesteps are the same + refiners_scheduler = LCMSolver(num_inference_steps=4) assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps) sample = randn(1, 4, 32, 32) @@ -143,7 +152,7 @@ def test_lcm_diffusers(): assert allclose(refiners_output, diffusers_output, rtol=0.01), f"outputs differ at step {step}" -def test_scheduler_remove_noise(): +def test_solver_remove_noise(): from diffusers import DDIMScheduler # type: ignore manual_seed(0) @@ -169,7 +178,7 @@ def test_scheduler_remove_noise(): assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}" -def test_scheduler_device(test_device: Device): +def test_solver_device(test_device: Device): if test_device.type == "cpu": warn("not running on CPU, skipping") pytest.skip() @@ -182,8 +191,35 @@ def test_scheduler_device(test_device: Device): @pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS]) -def test_scheduler_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device): +def test_solver_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device): scheduler = DDIM(num_inference_steps=30, device=test_device, noise_schedule=noise_schedule) assert len(scheduler.scale_factors) == 1000 assert scheduler.scale_factors[0] == 1 - scheduler.initial_diffusion_rate assert scheduler.scale_factors[-1] == 1 - scheduler.final_diffusion_rate + + +def test_solver_timestep_spacing(): + # Tests we get the results from [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 2. + linspace_int = Solver.generate_timesteps( + spacing=TimestepSpacing.LINSPACE_INT, + num_inference_steps=10, + num_train_timesteps=1000, + offset=1, + ) + assert equal(linspace_int, tensor([1000, 889, 778, 667, 556, 445, 334, 223, 112, 1])) + + leading = Solver.generate_timesteps( + spacing=TimestepSpacing.LEADING, + num_inference_steps=10, + num_train_timesteps=1000, + offset=1, + ) + assert equal(leading, tensor([901, 801, 701, 601, 501, 401, 301, 201, 101, 1])) + + trailing = Solver.generate_timesteps( + spacing=TimestepSpacing.TRAILING, + num_inference_steps=10, + num_train_timesteps=1000, + offset=1, + ) + assert equal(trailing, tensor([1000, 900, 800, 700, 600, 500, 400, 300, 200, 100]))