refactor solvers to support different timestep spacings

This commit is contained in:
Pierre Chapuis 2024-02-22 12:02:58 +01:00
parent d14c5bd5f8
commit ddc1cf8ca7
8 changed files with 220 additions and 97 deletions

View file

@ -3,6 +3,6 @@ from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
from refiners.foundationals.latent_diffusion.solvers.euler import Euler
from refiners.foundationals.latent_diffusion.solvers.lcm import LCMSolver
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler", "LCMSolver", "NoiseSchedule"]
__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler", "LCMSolver", "NoiseSchedule", "TimestepSpacing"]

View file

@ -1,6 +1,6 @@
from torch import Generator, Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
class DDIM(Solver):
@ -13,6 +13,8 @@ class DDIM(Solver):
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
timesteps_spacing: TimestepSpacing = TimestepSpacing.LEADING,
timesteps_offset: int = 1,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
@ -25,6 +27,8 @@ class DDIM(Solver):
Args:
num_inference_steps: The number of inference steps.
num_train_timesteps: The number of training timesteps.
timesteps_spacing: The spacing to use for the timesteps.
timesteps_offset: The offset to use for the timesteps.
initial_diffusion_rate: The initial diffusion rate.
final_diffusion_rate: The final diffusion rate.
noise_schedule: The noise schedule.
@ -35,6 +39,8 @@ class DDIM(Solver):
super().__init__(
num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
timesteps_spacing=timesteps_spacing,
timesteps_offset=timesteps_offset,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
@ -43,16 +49,18 @@ class DDIM(Solver):
dtype=dtype,
)
def _generate_timesteps(self) -> Tensor:
"""
Generates decreasing timesteps with 'leading' spacing and an offset of 1,
similar to the diffusers settings for the DDIM solver in Stable Diffusion 1.5.
"""
step_ratio = self.num_train_timesteps // self.num_inference_steps
timesteps = arange(start=0, end=self.num_inference_steps, step=1) * step_ratio + 1
return timesteps.flip(0)
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
"""Apply one step of the backward diffusion process.
Args:
x: The input tensor to apply the diffusion process to.
predicted_noise: The predicted noise tensor for the current step.
step: The current step of the diffusion process.
generator: The random number generator to use for sampling noise (ignored, this solver is deterministic).
Returns:
The denoised version of the input data `x`.
"""
assert self.first_inference_step <= step < self.num_inference_steps, f"invalid step {step}"
timestep, previous_timestep = (
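
For reference, the removed DDIM timestep code and the new generic path agree; a quick sketch checking the equivalence with the defaults shown above:

import torch
from refiners.foundationals.latent_diffusion.solvers import Solver, TimestepSpacing

# Old hard-coded DDIM logic: "leading" spacing with an offset of 1.
step_ratio = 1_000 // 30
old_timesteps = (torch.arange(start=0, end=30, step=1) * step_ratio + 1).flip(0)
# New generic path with the DDIM defaults introduced in this diff.
new_timesteps = Solver.generate_timesteps(
    spacing=TimestepSpacing.LEADING, num_inference_steps=30, num_train_timesteps=1_000, offset=1
)
assert torch.equal(old_timesteps, new_timesteps)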

View file

@ -1,6 +1,6 @@
from torch import Generator, Tensor, arange, device as Device
from torch import Generator, Tensor, device as Device
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing
class DDPM(Solver):
@ -17,6 +17,8 @@ class DDPM(Solver):
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
timesteps_spacing: TimestepSpacing = TimestepSpacing.LEADING,
timesteps_offset: int = 0,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
first_inference_step: int = 0,
@ -27,6 +29,8 @@ class DDPM(Solver):
Args:
num_inference_steps: The number of inference steps.
num_train_timesteps: The number of training timesteps.
timesteps_spacing: The spacing to use for the timesteps.
timesteps_offset: The offset to use for the timesteps.
initial_diffusion_rate: The initial diffusion rate.
final_diffusion_rate: The final diffusion rate.
first_inference_step: The first inference step.
@ -35,16 +39,13 @@ class DDPM(Solver):
super().__init__(
num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
timesteps_spacing=timesteps_spacing,
timesteps_offset=timesteps_offset,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
first_inference_step=first_inference_step,
device=device,
)
def _generate_timesteps(self) -> Tensor:
step_ratio = self.num_train_timesteps // self.num_inference_steps
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio
return timesteps.flip(0)
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
raise NotImplementedError
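
Since `__call__` raises `NotImplementedError`, DDPM is only used to add noise (e.g. during training). A hedged sketch, assuming `add_noise` keeps the base-class argument order (x, noise, step), which is only partially visible in this diff:

from torch import randn
from refiners.foundationals.latent_diffusion.solvers import DDPM

ddpm = DDPM(num_inference_steps=1000)
x, noise = randn(1, 4, 32, 32), randn(1, 4, 32, 32)
noisy_x = ddpm.add_noise(x, noise, 500)  # assumed argument order, see the base Solver below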

View file

@ -1,9 +1,8 @@
from collections import deque
import numpy as np
from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
class DPMSolver(Solver):
@ -23,6 +22,8 @@ class DPMSolver(Solver):
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
timesteps_spacing: TimestepSpacing = TimestepSpacing.TRAILING_ALT,
timesteps_offset: int = 0,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
last_step_first_order: bool = False,
@ -31,9 +32,26 @@ class DPMSolver(Solver):
device: Device | str = "cpu",
dtype: Dtype = float32,
):
"""Initializes a new DPM solver.
Args:
num_inference_steps: The number of inference steps.
num_train_timesteps: The number of training timesteps.
timesteps_spacing: The spacing to use for the timesteps.
timesteps_offset: The offset to use for the timesteps.
initial_diffusion_rate: The initial diffusion rate.
final_diffusion_rate: The final diffusion rate.
last_step_first_order: Use a first-order update for the last step.
noise_schedule: The noise schedule.
first_inference_step: The first inference step.
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
super().__init__(
num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
timesteps_spacing=timesteps_spacing,
timesteps_offset=timesteps_offset,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
@ -44,21 +62,6 @@ class DPMSolver(Solver):
self.estimated_data = deque([tensor([])] * 2, maxlen=2)
self.last_step_first_order = last_step_first_order
def _generate_timesteps(self) -> Tensor:
"""Generate the timesteps used by the solver.
Note:
We need to use numpy here because:
- numpy.linspace(0,999,31)[15] is 499.49999999999994
- torch.linspace(0,999,31)[15] is 499.5
and we want the same result as the original codebase.
"""
return tensor(
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:],
).flip(0)
def rebuild(
self: "DPMSolver",
num_inference_steps: int | None,
@ -148,10 +151,10 @@ class DPMSolver(Solver):
(ODEs).
Args:
x: The input data.
predicted_noise: The predicted noise.
step: The current step.
generator: The random number generator.
x: The input tensor to apply the diffusion process to.
predicted_noise: The predicted noise tensor for the current step.
step: The current step of the diffusion process.
generator: The random number generator to use for sampling noise (ignored, this solver is deterministic).
Returns:
The denoised version of the input data `x`.
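
The numpy-versus-torch note removed above (and kept in the new TRAILING_ALT branch of the base Solver) can be checked directly; a small sketch:

import numpy as np
import torch

# The two linspace implementations land on opposite sides of the rounding midpoint,
# hence numpy is used to reproduce the original DPM timesteps exactly.
print(np.linspace(0, 999, 31)[15])     # 499.49999999999994, rounds to 499
print(torch.linspace(0, 999, 31)[15])  # tensor(499.5000), would round to 500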

View file

@ -2,7 +2,7 @@ import numpy as np
import torch
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
class Euler(Solver):
@ -16,6 +16,8 @@ class Euler(Solver):
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
timesteps_spacing: TimestepSpacing = TimestepSpacing.LINSPACE_FLOAT,
timesteps_offset: int = 0,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
@ -28,6 +30,8 @@ class Euler(Solver):
Args:
num_inference_steps: The number of inference steps.
num_train_timesteps: The number of training timesteps.
timesteps_spacing: The spacing to use for the timesteps.
timesteps_offset: The offset to use for the timesteps.
initial_diffusion_rate: The initial diffusion rate.
final_diffusion_rate: The final diffusion rate.
noise_schedule: The noise schedule.
@ -40,6 +44,8 @@ class Euler(Solver):
super().__init__(
num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
timesteps_spacing=timesteps_spacing,
timesteps_offset=timesteps_offset,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
@ -54,20 +60,6 @@ class Euler(Solver):
"""The initial noise sigma."""
return self.sigmas.max()
def _generate_timesteps(self) -> Tensor:
"""Generate the timesteps used by the solver.
Note:
We need to use numpy here because:
- numpy.linspace(0,999,31)[15] is 499.49999999999994
- torch.linspace(0,999,31)[15] is 499.5
and we want the same result as the original codebase.
"""
timesteps = torch.tensor(np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps)).flip(0)
return timesteps
def _generate_sigmas(self) -> Tensor:
"""Generate the sigmas used by the solver."""
sigmas = self.noise_std / self.cumulative_scale_factors
@ -88,12 +80,17 @@ class Euler(Solver):
sigma = self.sigmas[step]
return x / ((sigma**2 + 1) ** 0.5)
def __call__(
self,
x: Tensor,
predicted_noise: Tensor,
step: int,
generator: Generator | None = None,
) -> Tensor:
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
"""Apply one step of the backward diffusion process.
Args:
x: The input tensor to apply the diffusion process to.
predicted_noise: The predicted noise tensor for the current step.
step: The current step of the diffusion process.
generator: The random number generator to use for sampling noise (ignored, this solver is deterministic).
Returns:
The denoised version of the input data `x`.
"""
assert self.first_inference_step <= step < self.num_inference_steps, f"invalid step {step}"
return x + predicted_noise * (self.sigmas[step + 1] - self.sigmas[step])
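
A minimal sketch of a single Euler step, using only the attributes shown above (`sigmas`) and the `__call__` signature; the random tensors stand in for a real latent and UNet prediction:

from torch import randn
from refiners.foundationals.latent_diffusion.solvers import Euler

solver = Euler(num_inference_steps=30)
x = randn(1, 4, 32, 32) * solver.sigmas.max()  # start from noise scaled by the initial sigma
predicted_noise = randn(1, 4, 32, 32)          # stand-in for the model output at step 0
x = solver(x, predicted_noise, step=0)         # i.e. x + predicted_noise * (sigmas[1] - sigmas[0])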

View file

@ -1,8 +1,7 @@
import numpy as np
import torch
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
class LCMSolver(Solver):
@ -20,14 +19,29 @@ class LCMSolver(Solver):
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
timesteps_spacing: TimestepSpacing = TimestepSpacing.TRAILING,
timesteps_offset: int = 0,
num_orig_steps: int = 50,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
diffusers_mode: bool = False,
device: torch.device | str = "cpu",
dtype: torch.dtype = torch.float32,
):
"""Initializes a new LCM solver.
Args:
num_inference_steps: The number of inference steps.
num_train_timesteps: The number of training timesteps.
timesteps_spacing: The spacing to use for the timesteps.
timesteps_offset: The offset to use for the timesteps.
num_orig_steps: The number of inference steps of the emulated DPM solver.
initial_diffusion_rate: The initial diffusion rate.
final_diffusion_rate: The final diffusion rate.
noise_schedule: The noise schedule.
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
assert (
num_orig_steps >= num_inference_steps
), f"num_orig_steps ({num_orig_steps}) < num_inference_steps ({num_inference_steps})"
@ -36,22 +50,17 @@ class LCMSolver(Solver):
DPMSolver(
num_inference_steps=num_orig_steps,
num_train_timesteps=num_train_timesteps,
timesteps_spacing=timesteps_spacing,
device=device,
dtype=dtype,
)
]
if diffusers_mode:
# Diffusers recomputes the timesteps in LCMScheduler,
# and it does it slightly differently than DPM Solver.
# We provide this option to reproduce Diffusers' output.
k = num_train_timesteps // num_orig_steps
ts = np.asarray(list(range(1, num_orig_steps + 1))) * k - 1
self.dpm.timesteps = torch.tensor(ts, device=device).flip(0)
super().__init__(
num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
timesteps_spacing=timesteps_spacing,
timesteps_offset=timesteps_offset,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
@ -85,12 +94,19 @@ class LCMSolver(Solver):
return self.dpm.timesteps[self.timestep_indices]
def __call__(
self,
x: torch.Tensor,
predicted_noise: torch.Tensor,
step: int,
generator: torch.Generator | None = None,
self, x: torch.Tensor, predicted_noise: torch.Tensor, step: int, generator: torch.Generator | None = None
) -> torch.Tensor:
"""Apply one step of the backward diffusion process.
Args:
x: The input tensor to apply the diffusion process to.
predicted_noise: The predicted noise tensor for the current step.
step: The current step of the diffusion process.
generator: The random number generator to use for sampling noise.
Returns:
The denoised version of the input data `x`.
"""
current_timestep = self.timesteps[step]
scale_factor = self.cumulative_scale_factors[current_timestep]
noise_ratio = self.noise_std[current_timestep]
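
A short sketch of the relationship between the LCM solver and the DPM solver it emulates, using the attributes referenced above:

from refiners.foundationals.latent_diffusion.solvers import LCMSolver

lcm = LCMSolver(num_inference_steps=4)
# The 4 LCM timesteps are a subset of the 50 timesteps (num_orig_steps) of the emulated
# DPM solver, which now receives the same TRAILING spacing introduced in this diff.
print(lcm.timesteps)
print(lcm.dpm.timesteps)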

View file

@ -2,7 +2,8 @@ from abc import ABC, abstractmethod
from enum import Enum
from typing import TypeVar
from torch import Generator, Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt
import numpy as np
from torch import Generator, Tensor, arange, device as Device, dtype as DType, float32, linspace, log, sqrt, tensor
from refiners.fluxion import layers as fl
@ -10,11 +11,11 @@ T = TypeVar("T", bound="Solver")
class NoiseSchedule(str, Enum):
"""An enumeration of noise schedules used to sample the noise schedule.
"""An enumeration of schedules used to sample the noise.
Attributes:
UNIFORM: A uniform noise schedule.
QUADRATIC: A quadratic noise schedule.
QUADRATIC: A quadratic noise schedule. Corresponds to "Stable Diffusion" in [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 1.
KARRAS: See [[arXiv:2206.00364] Elucidating the Design Space of Diffusion-Based Generative Models, Equation 5](https://arxiv.org/abs/2206.00364)
"""
@ -23,6 +24,26 @@ class NoiseSchedule(str, Enum):
KARRAS = "karras"
class TimestepSpacing(str, Enum):
"""An enumeration of methods to space the timesteps.
See [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 2.
Attributes:
LINSPACE_FLOAT: Sample N steps with linear interpolation, return a floating-point tensor.
LINSPACE_INT: Same as LINSPACE_FLOAT but return an integer tensor with rounded timesteps.
LEADING: Sample N+1 steps, do not include the last timestep (flawed: sampling then starts at a non-zero SNR). Used in DDIM.
TRAILING: Sample N+1 steps, do not include the first timestep.
TRAILING_ALT: Variant of TRAILING used in DPM.
"""
LINSPACE_FLOAT = "linspace_float"
LINSPACE_INT = "linspace_int"
LEADING = "leading"
TRAILING = "trailing"
TRAILING_ALT = "trailing_alt"
class Solver(fl.Module, ABC):
"""The base class for creating a diffusion model solver.
@ -39,6 +60,8 @@ class Solver(fl.Module, ABC):
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
timesteps_spacing: TimestepSpacing = TimestepSpacing.LINSPACE_FLOAT,
timesteps_offset: int = 0,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
@ -51,6 +74,8 @@ class Solver(fl.Module, ABC):
Args:
num_inference_steps: The number of inference steps to perform.
num_train_timesteps: The number of timesteps used to train the diffusion process.
timesteps_spacing: The spacing to use for the timesteps.
timesteps_offset: The offset to use for the timesteps.
initial_diffusion_rate: The initial diffusion rate used to sample the noise schedule.
final_diffusion_rate: The final diffusion rate used to sample the noise schedule.
noise_schedule: The noise schedule used to sample the noise.
@ -61,6 +86,8 @@ class Solver(fl.Module, ABC):
super().__init__()
self.num_inference_steps = num_inference_steps
self.num_train_timesteps = num_train_timesteps
self.timesteps_spacing = timesteps_spacing
self.timesteps_offset = timesteps_offset
self.initial_diffusion_rate = initial_diffusion_rate
self.final_diffusion_rate = final_diffusion_rate
self.noise_schedule = noise_schedule
@ -87,14 +114,49 @@ class Solver(fl.Module, ABC):
"""
...
@abstractmethod
def _generate_timesteps(self) -> Tensor:
"""Generate a tensor of timesteps.
@staticmethod
def generate_timesteps(
spacing: TimestepSpacing,
num_inference_steps: int,
num_train_timesteps: int = 1000,
offset: int = 0,
) -> Tensor:
"""Generate a tensor of timesteps according to a given spacing.
Note:
This method should be overridden by subclasses to provide the specific timesteps for the diffusion process.
Args:
spacing: The spacing to use for the timesteps.
num_inference_steps: The number of inference steps to perform.
num_train_timesteps: The number of timesteps used to train the diffusion process.
offset: The offset to use for the timesteps.
"""
...
max_timestep = num_train_timesteps - 1 + offset
match spacing:
case TimestepSpacing.LINSPACE_FLOAT:
return tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=float32).flip(0)
case TimestepSpacing.LINSPACE_INT:
return tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0)
case TimestepSpacing.LEADING:
step_ratio = num_train_timesteps // num_inference_steps
return (arange(0, num_inference_steps, 1) * step_ratio + offset).flip(0)
case TimestepSpacing.TRAILING:
step_ratio = num_train_timesteps // num_inference_steps
max_timestep = num_train_timesteps - 1 + offset
return arange(max_timestep, offset, -step_ratio)
case TimestepSpacing.TRAILING_ALT:
# We use numpy here because:
# numpy.linspace(0,999,31)[15] is 499.49999999999994
# torch.linspace(0,999,31)[15] is 499.5
# and we want the same result as the original DPM codebase.
np_space = np.linspace(offset, max_timestep, num_inference_steps + 1).round().astype(int)[1:]
return tensor(np_space).flip(0)
def _generate_timesteps(self) -> Tensor:
return self.generate_timesteps(
spacing=self.timesteps_spacing,
num_inference_steps=self.num_inference_steps,
num_train_timesteps=self.num_train_timesteps,
offset=self.timesteps_offset,
)
def add_noise(
self,
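
The two linspace spacings only differ in rounding and dtype; a quick sketch using the new static method (with 10 steps over the default 1000 training timesteps the interpolation points are whole numbers):

from torch import equal, tensor
from refiners.foundationals.latent_diffusion.solvers import Solver, TimestepSpacing

float_ts = Solver.generate_timesteps(spacing=TimestepSpacing.LINSPACE_FLOAT, num_inference_steps=10)
int_ts = Solver.generate_timesteps(spacing=TimestepSpacing.LINSPACE_INT, num_inference_steps=10)
assert equal(float_ts, tensor([999.0, 888.0, 777.0, 666.0, 555.0, 444.0, 333.0, 222.0, 111.0, 0.0]))
assert equal(int_ts, tensor([999, 888, 777, 666, 555, 444, 333, 222, 111, 0]))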

View file

@ -2,10 +2,19 @@ from typing import cast
from warnings import warn
import pytest
from torch import Generator, Tensor, allclose, device as Device, equal, isclose, randn
from torch import Generator, Tensor, allclose, device as Device, equal, isclose, randn, tensor
from refiners.fluxion import manual_seed
from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler, LCMSolver, NoiseSchedule
from refiners.foundationals.latent_diffusion.solvers import (
DDIM,
DDPM,
DPMSolver,
Euler,
LCMSolver,
NoiseSchedule,
Solver,
TimestepSpacing,
)
def test_ddpm_diffusers():
@ -14,7 +23,6 @@ def test_ddpm_diffusers():
diffusers_scheduler = DDPMScheduler(beta_schedule="scaled_linear", beta_start=0.00085, beta_end=0.012)
diffusers_scheduler.set_timesteps(1000)
refiners_scheduler = DDPM(num_inference_steps=1000)
assert equal(diffusers_scheduler.timesteps, refiners_scheduler.timesteps)
@ -34,6 +42,7 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
)
diffusers_scheduler.set_timesteps(n_steps)
refiners_scheduler = DPMSolver(num_inference_steps=n_steps, last_step_first_order=last_step_first_order)
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
sample = randn(1, 3, 32, 32)
predicted_noise = randn(1, 3, 32, 32)
@ -59,6 +68,7 @@ def test_ddim_diffusers():
)
diffusers_scheduler.set_timesteps(30)
refiners_scheduler = DDIM(num_inference_steps=30)
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
sample = randn(1, 4, 32, 32)
predicted_noise = randn(1, 4, 32, 32)
@ -85,6 +95,7 @@ def test_euler_diffusers():
)
diffusers_scheduler.set_timesteps(30)
refiners_scheduler = Euler(num_inference_steps=30)
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
sample = randn(1, 4, 32, 32)
predicted_noise = randn(1, 4, 32, 32)
@ -111,9 +122,7 @@ def test_lcm_diffusers():
diffusers_scheduler = LCMScheduler()
diffusers_scheduler.set_timesteps(4)
refiners_scheduler = LCMSolver(num_inference_steps=4, diffusers_mode=True)
# diffusers_mode means the timesteps are the same
refiners_scheduler = LCMSolver(num_inference_steps=4)
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
sample = randn(1, 4, 32, 32)
@ -143,7 +152,7 @@ def test_lcm_diffusers():
assert allclose(refiners_output, diffusers_output, rtol=0.01), f"outputs differ at step {step}"
def test_scheduler_remove_noise():
def test_solver_remove_noise():
from diffusers import DDIMScheduler # type: ignore
manual_seed(0)
@ -169,7 +178,7 @@ def test_scheduler_remove_noise():
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
def test_scheduler_device(test_device: Device):
def test_solver_device(test_device: Device):
if test_device.type == "cpu":
warn("not running on CPU, skipping")
pytest.skip()
@ -182,8 +191,35 @@ def test_scheduler_device(test_device: Device):
@pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS])
def test_scheduler_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device):
def test_solver_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device):
scheduler = DDIM(num_inference_steps=30, device=test_device, noise_schedule=noise_schedule)
assert len(scheduler.scale_factors) == 1000
assert scheduler.scale_factors[0] == 1 - scheduler.initial_diffusion_rate
assert scheduler.scale_factors[-1] == 1 - scheduler.final_diffusion_rate
def test_solver_timestep_spacing():
# Tests that we get the results from [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891), table 2.
linspace_int = Solver.generate_timesteps(
spacing=TimestepSpacing.LINSPACE_INT,
num_inference_steps=10,
num_train_timesteps=1000,
offset=1,
)
assert equal(linspace_int, tensor([1000, 889, 778, 667, 556, 445, 334, 223, 112, 1]))
leading = Solver.generate_timesteps(
spacing=TimestepSpacing.LEADING,
num_inference_steps=10,
num_train_timesteps=1000,
offset=1,
)
assert equal(leading, tensor([901, 801, 701, 601, 501, 401, 301, 201, 101, 1]))
trailing = Solver.generate_timesteps(
spacing=TimestepSpacing.TRAILING,
num_inference_steps=10,
num_train_timesteps=1000,
offset=1,
)
assert equal(trailing, tensor([1000, 900, 800, 700, 600, 500, 400, 300, 200, 100]))
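# With the default offset of 0 (used by the solvers above), trailing spacing
# starts exactly at the last training timestep.
trailing_zero = Solver.generate_timesteps(spacing=TimestepSpacing.TRAILING, num_inference_steps=10, offset=0)
assert equal(trailing_zero, tensor([999, 899, 799, 699, 599, 499, 399, 299, 199, 99]))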