refactor solvers to support different timestep spacings

Pierre Chapuis 2024-02-22 12:02:58 +01:00
parent d14c5bd5f8
commit ddc1cf8ca7
8 changed files with 220 additions and 97 deletions
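
In short: every solver's constructor now takes `timesteps_spacing` and `timesteps_offset` arguments, and the per-solver `_generate_timesteps` overrides are replaced by a shared `Solver.generate_timesteps` helper driven by a new `TimestepSpacing` enum. A minimal usage sketch based on the new signatures below (the spacing chosen here is illustrative, not a recommended setting):

from refiners.foundationals.latent_diffusion.solvers import DDIM, TimestepSpacing

# DDIM keeps its previous behavior by default (LEADING spacing, offset 1),
# but any other spacing can now be selected at construction time.
solver = DDIM(num_inference_steps=30, timesteps_spacing=TimestepSpacing.TRAILING, timesteps_offset=0)
print(solver.timesteps)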

src/refiners/foundationals/latent_diffusion/solvers/__init__.py

@@ -3,6 +3,6 @@ from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM
 from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
 from refiners.foundationals.latent_diffusion.solvers.euler import Euler
 from refiners.foundationals.latent_diffusion.solvers.lcm import LCMSolver
-from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
+from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing

-__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler", "LCMSolver", "NoiseSchedule"]
+__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler", "LCMSolver", "NoiseSchedule", "TimestepSpacing"]

src/refiners/foundationals/latent_diffusion/solvers/ddim.py

@@ -1,6 +1,6 @@
-from torch import Generator, Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor
+from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor

-from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
+from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing


 class DDIM(Solver):
@@ -13,6 +13,8 @@ class DDIM(Solver):
         self,
         num_inference_steps: int,
         num_train_timesteps: int = 1_000,
+        timesteps_spacing: TimestepSpacing = TimestepSpacing.LEADING,
+        timesteps_offset: int = 1,
         initial_diffusion_rate: float = 8.5e-4,
         final_diffusion_rate: float = 1.2e-2,
         noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
@@ -25,6 +27,8 @@ class DDIM(Solver):
         Args:
             num_inference_steps: The number of inference steps.
             num_train_timesteps: The number of training timesteps.
+            timesteps_spacing: The spacing to use for the timesteps.
+            timesteps_offset: The offset to use for the timesteps.
             initial_diffusion_rate: The initial diffusion rate.
             final_diffusion_rate: The final diffusion rate.
             noise_schedule: The noise schedule.
@@ -35,6 +39,8 @@ class DDIM(Solver):
         super().__init__(
             num_inference_steps=num_inference_steps,
             num_train_timesteps=num_train_timesteps,
+            timesteps_spacing=timesteps_spacing,
+            timesteps_offset=timesteps_offset,
             initial_diffusion_rate=initial_diffusion_rate,
             final_diffusion_rate=final_diffusion_rate,
             noise_schedule=noise_schedule,
@@ -43,16 +49,18 @@ class DDIM(Solver):
             dtype=dtype,
         )

-    def _generate_timesteps(self) -> Tensor:
-        """
-        Generates decreasing timesteps with 'leading' spacing and offset of 1
-        similar to diffusers settings for the DDIM solver in Stable Diffusion 1.5
-        """
-        step_ratio = self.num_train_timesteps // self.num_inference_steps
-        timesteps = arange(start=0, end=self.num_inference_steps, step=1) * step_ratio + 1
-        return timesteps.flip(0)
-
     def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
+        """Apply one step of the backward diffusion process.
+
+        Args:
+            x: The input tensor to apply the diffusion process to.
+            predicted_noise: The predicted noise tensor for the current step.
+            step: The current step of the diffusion process.
+            generator: The random number generator to use for sampling noise (ignored, this solver is deterministic).
+
+        Returns:
+            The denoised version of the input data `x`.
+        """
         assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"

         timestep, previous_timestep = (
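
Aside (not part of the commit): the removed DDIM-specific method above and the new generic LEADING spacing with an offset of 1 compute the same timesteps, which can be checked directly:

import torch

num_train_timesteps, num_inference_steps, offset = 1000, 10, 1
step_ratio = num_train_timesteps // num_inference_steps  # 100

# old DDIM._generate_timesteps, removed above:
old = (torch.arange(0, num_inference_steps) * step_ratio + 1).flip(0)
# new TimestepSpacing.LEADING with timesteps_offset=1 (see solver.py below):
new = (torch.arange(0, num_inference_steps) * step_ratio + offset).flip(0)

assert torch.equal(old, new)  # tensor([901, 801, ..., 101, 1])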

src/refiners/foundationals/latent_diffusion/solvers/ddpm.py

@@ -1,6 +1,6 @@
-from torch import Generator, Tensor, arange, device as Device
+from torch import Generator, Tensor, device as Device

-from refiners.foundationals.latent_diffusion.solvers.solver import Solver
+from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing


 class DDPM(Solver):
@@ -17,6 +17,8 @@ class DDPM(Solver):
         self,
         num_inference_steps: int,
         num_train_timesteps: int = 1_000,
+        timesteps_spacing: TimestepSpacing = TimestepSpacing.LEADING,
+        timesteps_offset: int = 0,
         initial_diffusion_rate: float = 8.5e-4,
         final_diffusion_rate: float = 1.2e-2,
         first_inference_step: int = 0,
@@ -27,6 +29,8 @@ class DDPM(Solver):
         Args:
             num_inference_steps: The number of inference steps.
             num_train_timesteps: The number of training timesteps.
+            timesteps_spacing: The spacing to use for the timesteps.
+            timesteps_offset: The offset to use for the timesteps.
             initial_diffusion_rate: The initial diffusion rate.
             final_diffusion_rate: The final diffusion rate.
             first_inference_step: The first inference step.
@@ -35,16 +39,13 @@ class DDPM(Solver):
         super().__init__(
             num_inference_steps=num_inference_steps,
             num_train_timesteps=num_train_timesteps,
+            timesteps_spacing=timesteps_spacing,
+            timesteps_offset=timesteps_offset,
             initial_diffusion_rate=initial_diffusion_rate,
             final_diffusion_rate=final_diffusion_rate,
             first_inference_step=first_inference_step,
             device=device,
         )

-    def _generate_timesteps(self) -> Tensor:
-        step_ratio = self.num_train_timesteps // self.num_inference_steps
-        timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio
-        return timesteps.flip(0)
-
     def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
         raise NotImplementedError

src/refiners/foundationals/latent_diffusion/solvers/dpm.py

@@ -1,9 +1,8 @@
 from collections import deque

-import numpy as np
 from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor

-from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
+from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing


 class DPMSolver(Solver):
@@ -23,6 +22,8 @@ class DPMSolver(Solver):
         self,
         num_inference_steps: int,
         num_train_timesteps: int = 1_000,
+        timesteps_spacing: TimestepSpacing = TimestepSpacing.TRAILING_ALT,
+        timesteps_offset: int = 0,
         initial_diffusion_rate: float = 8.5e-4,
         final_diffusion_rate: float = 1.2e-2,
         last_step_first_order: bool = False,
@@ -31,9 +32,26 @@ class DPMSolver(Solver):
         device: Device | str = "cpu",
         dtype: Dtype = float32,
     ):
+        """Initializes a new DPM solver.
+
+        Args:
+            num_inference_steps: The number of inference steps.
+            num_train_timesteps: The number of training timesteps.
+            timesteps_spacing: The spacing to use for the timesteps.
+            timesteps_offset: The offset to use for the timesteps.
+            initial_diffusion_rate: The initial diffusion rate.
+            final_diffusion_rate: The final diffusion rate.
+            last_step_first_order: Use a first-order update for the last step.
+            noise_schedule: The noise schedule.
+            first_inference_step: The first inference step.
+            device: The PyTorch device to use.
+            dtype: The PyTorch data type to use.
+        """
         super().__init__(
             num_inference_steps=num_inference_steps,
             num_train_timesteps=num_train_timesteps,
+            timesteps_spacing=timesteps_spacing,
+            timesteps_offset=timesteps_offset,
             initial_diffusion_rate=initial_diffusion_rate,
             final_diffusion_rate=final_diffusion_rate,
             noise_schedule=noise_schedule,
@@ -44,21 +62,6 @@ class DPMSolver(Solver):
         self.estimated_data = deque([tensor([])] * 2, maxlen=2)
         self.last_step_first_order = last_step_first_order

-    def _generate_timesteps(self) -> Tensor:
-        """Generate the timesteps used by the solver.
-
-        Note:
-            We need to use numpy here because:
-            - numpy.linspace(0,999,31)[15] is 499.49999999999994
-            - torch.linspace(0,999,31)[15] is 499.5
-            and we want the same result as the original codebase.
-        """
-        return tensor(
-            np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:],
-        ).flip(0)
-
     def rebuild(
         self: "DPMSolver",
         num_inference_steps: int | None,
@@ -148,10 +151,10 @@ class DPMSolver(Solver):
         (ODEs).

         Args:
-            x: The input data.
-            predicted_noise: The predicted noise.
-            step: The current step.
-            generator: The random number generator.
+            x: The input tensor to apply the diffusion process to.
+            predicted_noise: The predicted noise tensor for the current step.
+            step: The current step of the diffusion process.
+            generator: The random number generator to use for sampling noise (ignored, this solver is deterministic).

         Returns:
             The denoised version of the input data `x`.
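
The numpy-versus-torch note that used to live in DPMSolver._generate_timesteps (now a comment in the TRAILING_ALT branch of solver.py, below) is easy to verify standalone:

import numpy as np
import torch

# The half-integer midpoint differs between the two linspace implementations:
print(np.linspace(0, 999, 31)[15])     # 499.49999999999994, rounds to 499
print(torch.linspace(0, 999, 31)[15])  # tensor(499.5000), rounds to 500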

src/refiners/foundationals/latent_diffusion/solvers/euler.py

@@ -2,7 +2,7 @@ import numpy as np
 import torch
 from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor

-from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
+from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing


 class Euler(Solver):
@@ -16,6 +16,8 @@ class Euler(Solver):
         self,
         num_inference_steps: int,
         num_train_timesteps: int = 1_000,
+        timesteps_spacing: TimestepSpacing = TimestepSpacing.LINSPACE_FLOAT,
+        timesteps_offset: int = 0,
         initial_diffusion_rate: float = 8.5e-4,
         final_diffusion_rate: float = 1.2e-2,
         noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
@@ -28,6 +30,8 @@ class Euler(Solver):
         Args:
             num_inference_steps: The number of inference steps.
             num_train_timesteps: The number of training timesteps.
+            timesteps_spacing: The spacing to use for the timesteps.
+            timesteps_offset: The offset to use for the timesteps.
             initial_diffusion_rate: The initial diffusion rate.
             final_diffusion_rate: The final diffusion rate.
             noise_schedule: The noise schedule.
@@ -40,6 +44,8 @@ class Euler(Solver):
         super().__init__(
             num_inference_steps=num_inference_steps,
             num_train_timesteps=num_train_timesteps,
+            timesteps_spacing=timesteps_spacing,
+            timesteps_offset=timesteps_offset,
             initial_diffusion_rate=initial_diffusion_rate,
             final_diffusion_rate=final_diffusion_rate,
             noise_schedule=noise_schedule,
@@ -54,20 +60,6 @@ class Euler(Solver):
         """The initial noise sigma."""
         return self.sigmas.max()

-    def _generate_timesteps(self) -> Tensor:
-        """Generate the timesteps used by the solver.
-
-        Note:
-            We need to use numpy here because:
-            - numpy.linspace(0,999,31)[15] is 499.49999999999994
-            - torch.linspace(0,999,31)[15] is 499.5
-            and we want the same result as the original codebase.
-        """
-        timesteps = torch.tensor(np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps)).flip(0)
-        return timesteps
-
     def _generate_sigmas(self) -> Tensor:
         """Generate the sigmas used by the solver."""
         sigmas = self.noise_std / self.cumulative_scale_factors
@@ -88,12 +80,17 @@ class Euler(Solver):
         sigma = self.sigmas[step]
         return x / ((sigma**2 + 1) ** 0.5)

-    def __call__(
-        self,
-        x: Tensor,
-        predicted_noise: Tensor,
-        step: int,
-        generator: Generator | None = None,
-    ) -> Tensor:
+    def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
+        """Apply one step of the backward diffusion process.
+
+        Args:
+            x: The input tensor to apply the diffusion process to.
+            predicted_noise: The predicted noise tensor for the current step.
+            step: The current step of the diffusion process.
+            generator: The random number generator to use for sampling noise (ignored, this solver is deterministic).
+
+        Returns:
+            The denoised version of the input data `x`.
+        """
         assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"

         return x + predicted_noise * (self.sigmas[step + 1] - self.sigmas[step])

src/refiners/foundationals/latent_diffusion/solvers/lcm.py

@@ -1,8 +1,7 @@
-import numpy as np
 import torch

 from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
-from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
+from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing


 class LCMSolver(Solver):
@@ -20,14 +19,29 @@ class LCMSolver(Solver):
         self,
         num_inference_steps: int,
         num_train_timesteps: int = 1_000,
+        timesteps_spacing: TimestepSpacing = TimestepSpacing.TRAILING,
+        timesteps_offset: int = 0,
         num_orig_steps: int = 50,
         initial_diffusion_rate: float = 8.5e-4,
         final_diffusion_rate: float = 1.2e-2,
         noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
-        diffusers_mode: bool = False,
         device: torch.device | str = "cpu",
         dtype: torch.dtype = torch.float32,
     ):
+        """Initializes a new LCM solver.
+
+        Args:
+            num_inference_steps: The number of inference steps.
+            num_train_timesteps: The number of training timesteps.
+            timesteps_spacing: The spacing to use for the timesteps.
+            timesteps_offset: The offset to use for the timesteps.
+            num_orig_steps: The number of inference steps of the emulated DPM solver.
+            initial_diffusion_rate: The initial diffusion rate.
+            final_diffusion_rate: The final diffusion rate.
+            noise_schedule: The noise schedule.
+            device: The PyTorch device to use.
+            dtype: The PyTorch data type to use.
+        """
         assert (
             num_orig_steps >= num_inference_steps
         ), f"num_orig_steps ({num_orig_steps}) < num_inference_steps ({num_inference_steps})"
@@ -36,22 +50,17 @@ class LCMSolver(Solver):
             DPMSolver(
                 num_inference_steps=num_orig_steps,
                 num_train_timesteps=num_train_timesteps,
+                timesteps_spacing=timesteps_spacing,
                 device=device,
                 dtype=dtype,
             )
         ]

-        if diffusers_mode:
-            # Diffusers recomputes the timesteps in LCMScheduler,
-            # and it does it slightly differently than DPM Solver.
-            # We provide this option to reproduce Diffusers' output.
-            k = num_train_timesteps // num_orig_steps
-            ts = np.asarray(list(range(1, num_orig_steps + 1))) * k - 1
-            self.dpm.timesteps = torch.tensor(ts, device=device).flip(0)
-
         super().__init__(
             num_inference_steps=num_inference_steps,
             num_train_timesteps=num_train_timesteps,
+            timesteps_spacing=timesteps_spacing,
+            timesteps_offset=timesteps_offset,
             initial_diffusion_rate=initial_diffusion_rate,
             final_diffusion_rate=final_diffusion_rate,
             noise_schedule=noise_schedule,
@@ -85,12 +94,19 @@ class LCMSolver(Solver):
         return self.dpm.timesteps[self.timestep_indices]

     def __call__(
-        self,
-        x: torch.Tensor,
-        predicted_noise: torch.Tensor,
-        step: int,
-        generator: torch.Generator | None = None,
+        self, x: torch.Tensor, predicted_noise: torch.Tensor, step: int, generator: torch.Generator | None = None
     ) -> torch.Tensor:
+        """Apply one step of the backward diffusion process.
+
+        Args:
+            x: The input tensor to apply the diffusion process to.
+            predicted_noise: The predicted noise tensor for the current step.
+            step: The current step of the diffusion process.
+            generator: The random number generator to use for sampling noise.
+
+        Returns:
+            The denoised version of the input data `x`.
+        """
         current_timestep = self.timesteps[step]
         scale_factor = self.cumulative_scale_factors[current_timestep]
         noise_ratio = self.noise_std[current_timestep]
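
Note that the emulated DPM solver is now built with the solver's own timesteps_spacing (TRAILING by default), so the removed diffusers_mode flag is no longer needed: a default-constructed LCMSolver is expected to reproduce Diffusers' timesteps, as the updated test below asserts. For instance:

from refiners.foundationals.latent_diffusion.solvers import LCMSolver

solver = LCMSolver(num_inference_steps=4)
print(solver.timesteps)  # should match diffusers' LCMScheduler with 4 steps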

src/refiners/foundationals/latent_diffusion/solvers/solver.py

@@ -2,7 +2,8 @@ from abc import ABC, abstractmethod
 from enum import Enum
 from typing import TypeVar

-from torch import Generator, Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt
+import numpy as np
+from torch import Generator, Tensor, arange, device as Device, dtype as DType, float32, linspace, log, sqrt, tensor

 from refiners.fluxion import layers as fl
@@ -10,11 +11,11 @@ T = TypeVar("T", bound="Solver")

 class NoiseSchedule(str, Enum):
-    """An enumeration of noise schedules used to sample the noise schedule.
+    """An enumeration of schedules used to sample the noise.

     Attributes:
         UNIFORM: A uniform noise schedule.
-        QUADRATIC: A quadratic noise schedule.
+        QUADRATIC: A quadratic noise schedule. Corresponds to "Stable Diffusion" in [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 1.
         KARRAS: See [[arXiv:2206.00364] Elucidating the Design Space of Diffusion-Based Generative Models, Equation 5](https://arxiv.org/abs/2206.00364)
     """
@@ -23,6 +24,26 @@ class NoiseSchedule(str, Enum):
     KARRAS = "karras"


+class TimestepSpacing(str, Enum):
+    """An enumeration of methods to space the timesteps.
+
+    See [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 2.
+
+    Attributes:
+        LINSPACE_FLOAT: Sample N steps with linear interpolation, return a floating-point tensor.
+        LINSPACE_INT: Same as LINSPACE_FLOAT but return an integer tensor with rounded timesteps.
+        LEADING: Sample N+1 steps, do not include the last timestep (i.e. bad - non-zero SNR). Used in DDIM.
+        TRAILING: Sample N+1 steps, do not include the first timestep.
+        TRAILING_ALT: Variant of TRAILING used in DPM.
+    """
+
+    LINSPACE_FLOAT = "linspace_float"
+    LINSPACE_INT = "linspace_int"
+    LEADING = "leading"
+    TRAILING = "trailing"
+    TRAILING_ALT = "trailing_alt"
+
+
 class Solver(fl.Module, ABC):
     """The base class for creating a diffusion model solver.
@@ -39,6 +60,8 @@ class Solver(fl.Module, ABC):
         self,
         num_inference_steps: int,
         num_train_timesteps: int = 1_000,
+        timesteps_spacing: TimestepSpacing = TimestepSpacing.LINSPACE_FLOAT,
+        timesteps_offset: int = 0,
         initial_diffusion_rate: float = 8.5e-4,
         final_diffusion_rate: float = 1.2e-2,
         noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
@@ -51,6 +74,8 @@ class Solver(fl.Module, ABC):
         Args:
             num_inference_steps: The number of inference steps to perform.
             num_train_timesteps: The number of timesteps used to train the diffusion process.
+            timesteps_spacing: The spacing to use for the timesteps.
+            timesteps_offset: The offset to use for the timesteps.
             initial_diffusion_rate: The initial diffusion rate used to sample the noise schedule.
             final_diffusion_rate: The final diffusion rate used to sample the noise schedule.
             noise_schedule: The noise schedule used to sample the noise schedule.
@@ -61,6 +86,8 @@ class Solver(fl.Module, ABC):
         super().__init__()
         self.num_inference_steps = num_inference_steps
         self.num_train_timesteps = num_train_timesteps
+        self.timesteps_spacing = timesteps_spacing
+        self.timesteps_offset = timesteps_offset
         self.initial_diffusion_rate = initial_diffusion_rate
         self.final_diffusion_rate = final_diffusion_rate
         self.noise_schedule = noise_schedule
@@ -87,14 +114,49 @@ class Solver(fl.Module, ABC):
         """
         ...

-    @abstractmethod
-    def _generate_timesteps(self) -> Tensor:
-        """Generate a tensor of timesteps.
-
-        Note:
-            This method should be overridden by subclasses to provide the specific timesteps for the diffusion process.
-        """
-        ...
+    @staticmethod
+    def generate_timesteps(
+        spacing: TimestepSpacing,
+        num_inference_steps: int,
+        num_train_timesteps: int = 1000,
+        offset: int = 0,
+    ) -> Tensor:
+        """Generate a tensor of timesteps according to a given spacing.
+
+        Args:
+            spacing: The spacing to use for the timesteps.
+            num_inference_steps: The number of inference steps to perform.
+            num_train_timesteps: The number of timesteps used to train the diffusion process.
+            offset: The offset to use for the timesteps.
+        """
+        max_timestep = num_train_timesteps - 1 + offset
+        match spacing:
+            case TimestepSpacing.LINSPACE_FLOAT:
+                return tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=float32).flip(0)
+            case TimestepSpacing.LINSPACE_INT:
+                return tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0)
+            case TimestepSpacing.LEADING:
+                step_ratio = num_train_timesteps // num_inference_steps
+                return (arange(0, num_inference_steps, 1) * step_ratio + offset).flip(0)
+            case TimestepSpacing.TRAILING:
+                step_ratio = num_train_timesteps // num_inference_steps
+                max_timestep = num_train_timesteps - 1 + offset
+                return arange(max_timestep, offset, -step_ratio)
+            case TimestepSpacing.TRAILING_ALT:
+                # We use numpy here because:
+                # numpy.linspace(0,999,31)[15] is 499.49999999999994
+                # torch.linspace(0,999,31)[15] is 499.5
+                # and we want the same result as the original DPM codebase.
+                np_space = np.linspace(offset, max_timestep, num_inference_steps + 1).round().astype(int)[1:]
+                return tensor(np_space).flip(0)
+
+    def _generate_timesteps(self) -> Tensor:
+        return self.generate_timesteps(
+            spacing=self.timesteps_spacing,
+            num_inference_steps=self.num_inference_steps,
+            num_train_timesteps=self.num_train_timesteps,
+            offset=self.timesteps_offset,
+        )

     def add_noise(
         self,
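
Since generate_timesteps is a static method, it can be called without instantiating a solver. A small sketch for one of the spacings the new test below does not cover, LINSPACE_FLOAT (default offset of 0 assumed):

from refiners.foundationals.latent_diffusion.solvers import Solver, TimestepSpacing

timesteps = Solver.generate_timesteps(
    spacing=TimestepSpacing.LINSPACE_FLOAT,
    num_inference_steps=10,
    num_train_timesteps=1000,
)
print(timesteps)  # tensor([999., 888., 777., ..., 111., 0.]), a float tensor as the name implies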

tests/foundationals/latent_diffusion/test_solvers.py

@@ -2,10 +2,19 @@ from typing import cast
 from warnings import warn

 import pytest
-from torch import Generator, Tensor, allclose, device as Device, equal, isclose, randn
+from torch import Generator, Tensor, allclose, device as Device, equal, isclose, randn, tensor

 from refiners.fluxion import manual_seed
-from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler, LCMSolver, NoiseSchedule
+from refiners.foundationals.latent_diffusion.solvers import (
+    DDIM,
+    DDPM,
+    DPMSolver,
+    Euler,
+    LCMSolver,
+    NoiseSchedule,
+    Solver,
+    TimestepSpacing,
+)


 def test_ddpm_diffusers():
@@ -14,7 +23,6 @@ def test_ddpm_diffusers():
     diffusers_scheduler = DDPMScheduler(beta_schedule="scaled_linear", beta_start=0.00085, beta_end=0.012)
     diffusers_scheduler.set_timesteps(1000)
     refiners_scheduler = DDPM(num_inference_steps=1000)
-
     assert equal(diffusers_scheduler.timesteps, refiners_scheduler.timesteps)
@@ -34,6 +42,7 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
     )
     diffusers_scheduler.set_timesteps(n_steps)
     refiners_scheduler = DPMSolver(num_inference_steps=n_steps, last_step_first_order=last_step_first_order)
+    assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)

     sample = randn(1, 3, 32, 32)
     predicted_noise = randn(1, 3, 32, 32)
@@ -59,6 +68,7 @@ def test_ddim_diffusers():
     )
     diffusers_scheduler.set_timesteps(30)
     refiners_scheduler = DDIM(num_inference_steps=30)
+    assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)

     sample = randn(1, 4, 32, 32)
     predicted_noise = randn(1, 4, 32, 32)
@@ -85,6 +95,7 @@ def test_euler_diffusers():
     )
     diffusers_scheduler.set_timesteps(30)
     refiners_scheduler = Euler(num_inference_steps=30)
+    assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)

     sample = randn(1, 4, 32, 32)
     predicted_noise = randn(1, 4, 32, 32)
@@ -111,9 +122,7 @@ def test_lcm_diffusers():
     diffusers_scheduler = LCMScheduler()
     diffusers_scheduler.set_timesteps(4)
-    refiners_scheduler = LCMSolver(num_inference_steps=4, diffusers_mode=True)
-
-    # diffusers_mode means the timesteps are the same
+    refiners_scheduler = LCMSolver(num_inference_steps=4)
     assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)

     sample = randn(1, 4, 32, 32)
@@ -143,7 +152,7 @@ def test_lcm_diffusers():
         assert allclose(refiners_output, diffusers_output, rtol=0.01), f"outputs differ at step {step}"


-def test_scheduler_remove_noise():
+def test_solver_remove_noise():
     from diffusers import DDIMScheduler  # type: ignore

     manual_seed(0)
@@ -169,7 +178,7 @@ def test_solver_remove_noise():
         assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"


-def test_scheduler_device(test_device: Device):
+def test_solver_device(test_device: Device):
     if test_device.type == "cpu":
         warn("not running on CPU, skipping")
         pytest.skip()
@@ -182,8 +191,35 @@ def test_solver_device(test_device: Device):

 @pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS])
-def test_scheduler_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device):
+def test_solver_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device):
     scheduler = DDIM(num_inference_steps=30, device=test_device, noise_schedule=noise_schedule)
     assert len(scheduler.scale_factors) == 1000
     assert scheduler.scale_factors[0] == 1 - scheduler.initial_diffusion_rate
     assert scheduler.scale_factors[-1] == 1 - scheduler.final_diffusion_rate
+
+
+def test_solver_timestep_spacing():
+    # Tests we get the results from [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 2.
+    linspace_int = Solver.generate_timesteps(
+        spacing=TimestepSpacing.LINSPACE_INT,
+        num_inference_steps=10,
+        num_train_timesteps=1000,
+        offset=1,
+    )
+    assert equal(linspace_int, tensor([1000, 889, 778, 667, 556, 445, 334, 223, 112, 1]))
+
+    leading = Solver.generate_timesteps(
+        spacing=TimestepSpacing.LEADING,
+        num_inference_steps=10,
+        num_train_timesteps=1000,
+        offset=1,
+    )
+    assert equal(leading, tensor([901, 801, 701, 601, 501, 401, 301, 201, 101, 1]))
+
+    trailing = Solver.generate_timesteps(
+        spacing=TimestepSpacing.TRAILING,
+        num_inference_steps=10,
+        num_train_timesteps=1000,
+        offset=1,
+    )
+    assert equal(trailing, tensor([1000, 900, 800, 700, 600, 500, 400, 300, 200, 100]))