From bf0ba58541a718c4cd15aa748a806b72c4ee90bb Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Thu, 22 Feb 2024 15:16:22 +0100 Subject: [PATCH] refactor solver params, add sample prediction type --- .../latent_diffusion/solvers/__init__.py | 21 ++- .../latent_diffusion/solvers/ddim.py | 43 +++--- .../latent_diffusion/solvers/ddpm.py | 40 ++--- .../latent_diffusion/solvers/dpm.py | 57 ++++--- .../latent_diffusion/solvers/euler.py | 44 +++--- .../latent_diffusion/solvers/lcm.py | 50 +++--- .../latent_diffusion/solvers/solver.py | 142 ++++++++++++------ tests/e2e/test_diffusion.py | 4 +- .../latent_diffusion/test_solvers.py | 26 +++- 9 files changed, 260 insertions(+), 167 deletions(-) diff --git a/src/refiners/foundationals/latent_diffusion/solvers/__init__.py b/src/refiners/foundationals/latent_diffusion/solvers/__init__.py index 475d959..7f904a5 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/__init__.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/__init__.py @@ -3,6 +3,23 @@ from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver from refiners.foundationals.latent_diffusion.solvers.euler import Euler from refiners.foundationals.latent_diffusion.solvers.lcm import LCMSolver -from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing +from refiners.foundationals.latent_diffusion.solvers.solver import ( + ModelPredictionType, + NoiseSchedule, + Solver, + SolverParams, + TimestepSpacing, +) -__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler", "LCMSolver", "NoiseSchedule", "TimestepSpacing"] +__all__ = [ + "Solver", + "SolverParams", + "DPMSolver", + "DDPM", + "DDIM", + "Euler", + "LCMSolver", + "ModelPredictionType", + "NoiseSchedule", + "TimestepSpacing", +] diff --git a/src/refiners/foundationals/latent_diffusion/solvers/ddim.py b/src/refiners/foundationals/latent_diffusion/solvers/ddim.py index 9ea146e..d4d5a81 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/ddim.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/ddim.py @@ -1,6 +1,13 @@ +import dataclasses + from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor -from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing +from refiners.foundationals.latent_diffusion.solvers.solver import ( + ModelPredictionType, + Solver, + SolverParams, + TimestepSpacing, +) class DDIM(Solver): @@ -9,42 +16,36 @@ class DDIM(Solver): See [[arXiv:2010.02502] Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) for more details. """ + default_params = dataclasses.replace( + Solver.default_params, + timesteps_spacing=TimestepSpacing.LEADING, + timesteps_offset=1, + ) + def __init__( self, num_inference_steps: int, - num_train_timesteps: int = 1_000, - timesteps_spacing: TimestepSpacing = TimestepSpacing.LEADING, - timesteps_offset: int = 1, - initial_diffusion_rate: float = 8.5e-4, - final_diffusion_rate: float = 1.2e-2, - noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC, first_inference_step: int = 0, + params: SolverParams | None = None, device: Device | str = "cpu", dtype: Dtype = float32, ) -> None: """Initializes a new DDIM solver. Args: - num_inference_steps: The number of inference steps. - num_train_timesteps: The number of training timesteps. - timesteps_spacing: The spacing to use for the timesteps. - timesteps_offset: The offset to use for the timesteps. - initial_diffusion_rate: The initial diffusion rate. - final_diffusion_rate: The final diffusion rate. - noise_schedule: The noise schedule. - first_inference_step: The first inference step. + num_inference_steps: The number of inference steps to perform. + first_inference_step: The first inference step to perform. + params: The common parameters for solvers. device: The PyTorch device to use. dtype: The PyTorch data type to use. """ + if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None): + raise NotImplementedError + super().__init__( num_inference_steps=num_inference_steps, - num_train_timesteps=num_train_timesteps, - timesteps_spacing=timesteps_spacing, - timesteps_offset=timesteps_offset, - initial_diffusion_rate=initial_diffusion_rate, - final_diffusion_rate=final_diffusion_rate, - noise_schedule=noise_schedule, first_inference_step=first_inference_step, + params=params, device=device, dtype=dtype, ) diff --git a/src/refiners/foundationals/latent_diffusion/solvers/ddpm.py b/src/refiners/foundationals/latent_diffusion/solvers/ddpm.py index cd8547e..442a63a 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/ddpm.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/ddpm.py @@ -1,6 +1,13 @@ +import dataclasses + from torch import Generator, Tensor, device as Device -from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing +from refiners.foundationals.latent_diffusion.solvers.solver import ( + ModelPredictionType, + Solver, + SolverParams, + TimestepSpacing, +) class DDPM(Solver): @@ -13,37 +20,34 @@ class DDPM(Solver): See [[arXiv:2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) for more details. """ + default_params = dataclasses.replace( + Solver.default_params, + timesteps_spacing=TimestepSpacing.LEADING, + ) + def __init__( self, num_inference_steps: int, - num_train_timesteps: int = 1_000, - timesteps_spacing: TimestepSpacing = TimestepSpacing.LEADING, - timesteps_offset: int = 0, - initial_diffusion_rate: float = 8.5e-4, - final_diffusion_rate: float = 1.2e-2, first_inference_step: int = 0, + params: SolverParams | None = None, device: Device | str = "cpu", ) -> None: """Initializes a new DDPM solver. Args: - num_inference_steps: The number of inference steps. - num_train_timesteps: The number of training timesteps. - timesteps_spacing: The spacing to use for the timesteps. - timesteps_offset: The offset to use for the timesteps. - initial_diffusion_rate: The initial diffusion rate. - final_diffusion_rate: The final diffusion rate. - first_inference_step: The first inference step. + num_inference_steps: The number of inference steps to perform. + first_inference_step: The first inference step to perform. + params: The common parameters for solvers. device: The PyTorch device to use. """ + + if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None): + raise NotImplementedError + super().__init__( num_inference_steps=num_inference_steps, - num_train_timesteps=num_train_timesteps, - timesteps_spacing=timesteps_spacing, - timesteps_offset=timesteps_offset, - initial_diffusion_rate=initial_diffusion_rate, - final_diffusion_rate=final_diffusion_rate, first_inference_step=first_inference_step, + params=params, device=device, ) diff --git a/src/refiners/foundationals/latent_diffusion/solvers/dpm.py b/src/refiners/foundationals/latent_diffusion/solvers/dpm.py index 5ef3bfd..15a953e 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/dpm.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/dpm.py @@ -1,8 +1,15 @@ +import dataclasses from collections import deque +import numpy as np from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor -from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing +from refiners.foundationals.latent_diffusion.solvers.solver import ( + ModelPredictionType, + Solver, + SolverParams, + TimestepSpacing, +) class DPMSolver(Solver): @@ -18,44 +25,37 @@ class DPMSolver(Solver): for the last step of the diffusion. """ + default_params = dataclasses.replace( + Solver.default_params, + timesteps_spacing=TimestepSpacing.CUSTOM, + ) + def __init__( self, num_inference_steps: int, - num_train_timesteps: int = 1_000, - timesteps_spacing: TimestepSpacing = TimestepSpacing.TRAILING_ALT, - timesteps_offset: int = 0, - initial_diffusion_rate: float = 8.5e-4, - final_diffusion_rate: float = 1.2e-2, - last_step_first_order: bool = False, - noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC, first_inference_step: int = 0, + params: SolverParams | None = None, + last_step_first_order: bool = False, device: Device | str = "cpu", dtype: Dtype = float32, ): """Initializes a new DPM solver. Args: - num_inference_steps: The number of inference steps. - num_train_timesteps: The number of training timesteps. - timesteps_spacing: The spacing to use for the timesteps. - timesteps_offset: The offset to use for the timesteps. - initial_diffusion_rate: The initial diffusion rate. - final_diffusion_rate: The final diffusion rate. + num_inference_steps: The number of inference steps to perform. + first_inference_step: The first inference step to perform. + params: The common parameters for solvers. last_step_first_order: Use a first-order update for the last step. - noise_schedule: The noise schedule. - first_inference_step: The first inference step. device: The PyTorch device to use. dtype: The PyTorch data type to use. """ + if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None): + raise NotImplementedError + super().__init__( num_inference_steps=num_inference_steps, - num_train_timesteps=num_train_timesteps, - timesteps_spacing=timesteps_spacing, - timesteps_offset=timesteps_offset, - initial_diffusion_rate=initial_diffusion_rate, - final_diffusion_rate=final_diffusion_rate, - noise_schedule=noise_schedule, first_inference_step=first_inference_step, + params=params, device=device, dtype=dtype, ) @@ -80,6 +80,19 @@ class DPMSolver(Solver): r.last_step_first_order = self.last_step_first_order return r + def _generate_timesteps(self) -> Tensor: + if self.params.timesteps_spacing != TimestepSpacing.CUSTOM: + return super()._generate_timesteps() + + # We use numpy here because: + # numpy.linspace(0,999,31)[15] is 499.49999999999994 + # torch.linspace(0,999,31)[15] is 499.5 + # and we want the same result as the original DPM codebase. + offset = self.params.timesteps_offset + max_timestep = self.params.num_train_timesteps - 1 + offset + np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:] + return tensor(np_space).flip(0) + def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor: """Applies a first-order backward Euler update to the input data `x`. diff --git a/src/refiners/foundationals/latent_diffusion/solvers/euler.py b/src/refiners/foundationals/latent_diffusion/solvers/euler.py index fc2ef73..3c062ec 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/euler.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/euler.py @@ -2,7 +2,12 @@ import numpy as np import torch from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor -from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing +from refiners.foundationals.latent_diffusion.solvers.solver import ( + ModelPredictionType, + NoiseSchedule, + Solver, + SolverParams, +) class Euler(Solver): @@ -15,41 +20,27 @@ class Euler(Solver): def __init__( self, num_inference_steps: int, - num_train_timesteps: int = 1_000, - timesteps_spacing: TimestepSpacing = TimestepSpacing.LINSPACE_FLOAT, - timesteps_offset: int = 0, - initial_diffusion_rate: float = 8.5e-4, - final_diffusion_rate: float = 1.2e-2, - noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC, first_inference_step: int = 0, + params: SolverParams | None = None, device: Device | str = "cpu", dtype: Dtype = float32, ): """Initializes a new Euler solver. Args: - num_inference_steps: The number of inference steps. - num_train_timesteps: The number of training timesteps. - timesteps_spacing: The spacing to use for the timesteps. - timesteps_offset: The offset to use for the timesteps. - initial_diffusion_rate: The initial diffusion rate. - final_diffusion_rate: The final diffusion rate. - noise_schedule: The noise schedule. - first_inference_step: The first inference step. + num_inference_steps: The number of inference steps to perform. + first_inference_step: The first inference step to perform. + params: The common parameters for solvers. device: The PyTorch device to use. dtype: The PyTorch data type to use. """ - if noise_schedule != NoiseSchedule.QUADRATIC: + if params and params.noise_schedule not in (NoiseSchedule.QUADRATIC, None): raise NotImplementedError + super().__init__( num_inference_steps=num_inference_steps, - num_train_timesteps=num_train_timesteps, - timesteps_spacing=timesteps_spacing, - timesteps_offset=timesteps_offset, - initial_diffusion_rate=initial_diffusion_rate, - final_diffusion_rate=final_diffusion_rate, - noise_schedule=noise_schedule, first_inference_step=first_inference_step, + params=params, device=device, dtype=dtype, ) @@ -85,7 +76,7 @@ class Euler(Solver): Args: x: The input tensor to apply the diffusion process to. - predicted_noise: The predicted noise tensor for the current step. + predicted_noise: The predicted noise tensor for the current step (or x0 if the prediction type is SAMPLE). step: The current step of the diffusion process. generator: The random number generator to use for sampling noise (ignored, this solver is deterministic). @@ -93,4 +84,11 @@ class Euler(Solver): The denoised version of the input data `x`. """ assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}" + + if self.params.model_prediction_type == ModelPredictionType.SAMPLE: + x0 = predicted_noise # the model does not actually predict the noise but x0 + ratio = self.sigmas[step + 1] / self.sigmas[step] + return ratio * x + (1 - ratio) * x0 + + assert self.params.model_prediction_type == ModelPredictionType.NOISE return x + predicted_noise * (self.sigmas[step + 1] - self.sigmas[step]) diff --git a/src/refiners/foundationals/latent_diffusion/solvers/lcm.py b/src/refiners/foundationals/latent_diffusion/solvers/lcm.py index 0ffd581..c7087d2 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/lcm.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/lcm.py @@ -1,7 +1,14 @@ +import dataclasses + import torch from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver -from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing +from refiners.foundationals.latent_diffusion.solvers.solver import ( + ModelPredictionType, + Solver, + SolverParams, + TimestepSpacing, +) class LCMSolver(Solver): @@ -15,55 +22,52 @@ class LCMSolver(Solver): for details. """ + # The spacing parameter is actually the spacing of the underlying DPM solver. + default_params = dataclasses.replace(Solver.default_params, timesteps_spacing=TimestepSpacing.TRAILING) + def __init__( self, num_inference_steps: int, - num_train_timesteps: int = 1_000, - timesteps_spacing: TimestepSpacing = TimestepSpacing.TRAILING, - timesteps_offset: int = 0, + first_inference_step: int = 0, + params: SolverParams | None = None, num_orig_steps: int = 50, - initial_diffusion_rate: float = 8.5e-4, - final_diffusion_rate: float = 1.2e-2, - noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC, device: torch.device | str = "cpu", dtype: torch.dtype = torch.float32, ): """Initializes a new LCM solver. Args: - num_inference_steps: The number of inference steps. - num_train_timesteps: The number of training timesteps. - timesteps_spacing: The spacing to use for the timesteps. - timesteps_offset: The offset to use for the timesteps. + num_inference_steps: The number of inference steps to perform. + first_inference_step: The first inference step to perform. + params: The common parameters for solvers. num_orig_steps: The number of inference steps of the emulated DPM solver. - initial_diffusion_rate: The initial diffusion rate. - final_diffusion_rate: The final diffusion rate. - noise_schedule: The noise schedule. device: The PyTorch device to use. dtype: The PyTorch data type to use. """ + assert ( num_orig_steps >= num_inference_steps ), f"num_orig_steps ({num_orig_steps}) < num_inference_steps ({num_inference_steps})" + params = self.resolve_params(params) + if params.model_prediction_type != ModelPredictionType.NOISE: + raise NotImplementedError + self._dpm = [ DPMSolver( num_inference_steps=num_orig_steps, - num_train_timesteps=num_train_timesteps, - timesteps_spacing=timesteps_spacing, + params=SolverParams( + num_train_timesteps=params.num_train_timesteps, + timesteps_spacing=params.timesteps_spacing, + ), device=device, dtype=dtype, ) ] - super().__init__( num_inference_steps=num_inference_steps, - num_train_timesteps=num_train_timesteps, - timesteps_spacing=timesteps_spacing, - timesteps_offset=timesteps_offset, - initial_diffusion_rate=initial_diffusion_rate, - final_diffusion_rate=final_diffusion_rate, - noise_schedule=noise_schedule, + first_inference_step=first_inference_step, + params=params, device=device, dtype=dtype, ) diff --git a/src/refiners/foundationals/latent_diffusion/solvers/solver.py b/src/refiners/foundationals/latent_diffusion/solvers/solver.py index 32ef32d..ad9601c 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/solver.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/solver.py @@ -1,3 +1,4 @@ +import dataclasses from abc import ABC, abstractmethod from enum import Enum from typing import TypeVar @@ -30,18 +31,64 @@ class TimestepSpacing(str, Enum): See [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 2. Attributes: - LINSPACE_FLOAT: Sample N steps with linear interpolation, return a floating-point tensor. - LINSPACE_INT: Same as LINSPACE_FLOAT but return an integer tensor with rounded timesteps. + LINSPACE: Sample N steps with linear interpolation, return a floating-point tensor. + LINSPACE_ROUNDED: Same as LINSPACE but return an integer tensor with rounded timesteps. LEADING: Sample N+1 steps, do not include the last timestep (i.e. bad - non-zero SNR). Used in DDIM. TRAILING: Sample N+1 steps, do not include the first timestep. - TRAILING_ALT: Variant of TRAILING used in DPM. + CUSTOM: Use custom timespacing in solver (override `_generate_timesteps`, see DPM). """ - LINSPACE_FLOAT = "linspace_float" - LINSPACE_INT = "linspace_int" + LINSPACE = "linspace" + LINSPACE_ROUNDED = "linspace_rounded" LEADING = "leading" TRAILING = "trailing" - TRAILING_ALT = "trailing_alt" + CUSTOM = "custom" + + +class ModelPredictionType(str, Enum): + """An enumeration of possible outputs of the model. + + Attributes: + NOISE: The model predicts the noise (epsilon). + SAMPLE: The model predicts the denoised sample (x0). + """ + + NOISE = "noise" + SAMPLE = "sample" + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class SolverParams: + """Common parameters for solvers. + + Args: + num_train_timesteps: The number of timesteps used to train the diffusion process. + timesteps_spacing: The spacing to use for the timesteps. + timesteps_offset: The offset to use for the timesteps. + initial_diffusion_rate: The initial diffusion rate used to sample the noise schedule. + final_diffusion_rate: The final diffusion rate used to sample the noise schedule. + noise_schedule: The noise schedule used to sample the noise schedule. + model_prediction_type: Defines what the model predicts. + """ + + num_train_timesteps: int | None = None + timesteps_spacing: TimestepSpacing | None = None + timesteps_offset: int | None = None + initial_diffusion_rate: float | None = None + final_diffusion_rate: float | None = None + noise_schedule: NoiseSchedule | None = None + model_prediction_type: ModelPredictionType | None = None + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class ResolvedSolverParams(SolverParams): + num_train_timesteps: int + timesteps_spacing: TimestepSpacing + timesteps_offset: int + initial_diffusion_rate: float + final_diffusion_rate: float + noise_schedule: NoiseSchedule + model_prediction_type: ModelPredictionType class Solver(fl.Module, ABC): @@ -55,17 +102,23 @@ class Solver(fl.Module, ABC): """ timesteps: Tensor + params: ResolvedSolverParams + + default_params = ResolvedSolverParams( + num_train_timesteps=1000, + timesteps_spacing=TimestepSpacing.LINSPACE, + timesteps_offset=0, + initial_diffusion_rate=8.5e-4, + final_diffusion_rate=1.2e-2, + noise_schedule=NoiseSchedule.QUADRATIC, + model_prediction_type=ModelPredictionType.NOISE, + ) def __init__( self, num_inference_steps: int, - num_train_timesteps: int = 1_000, - timesteps_spacing: TimestepSpacing = TimestepSpacing.LINSPACE_FLOAT, - timesteps_offset: int = 0, - initial_diffusion_rate: float = 8.5e-4, - final_diffusion_rate: float = 1.2e-2, - noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC, first_inference_step: int = 0, + params: SolverParams | None = None, device: Device | str = "cpu", dtype: DType = float32, ) -> None: @@ -73,32 +126,33 @@ class Solver(fl.Module, ABC): Args: num_inference_steps: The number of inference steps to perform. - num_train_timesteps: The number of timesteps used to train the diffusion process. - timesteps_spacing: The spacing to use for the timesteps. - timesteps_offset: The offset to use for the timesteps. - initial_diffusion_rate: The initial diffusion rate used to sample the noise schedule. - final_diffusion_rate: The final diffusion rate used to sample the noise schedule. - noise_schedule: The noise schedule used to sample the noise schedule. first_inference_step: The first inference step to perform. + params: The common parameters for solvers. device: The PyTorch device to use for the solver's tensors. dtype: The PyTorch data type to use for the solver's tensors. """ super().__init__() + self.num_inference_steps = num_inference_steps - self.num_train_timesteps = num_train_timesteps - self.timesteps_spacing = timesteps_spacing - self.timesteps_offset = timesteps_offset - self.initial_diffusion_rate = initial_diffusion_rate - self.final_diffusion_rate = final_diffusion_rate - self.noise_schedule = noise_schedule self.first_inference_step = first_inference_step + self.params = self.resolve_params(params) + self.scale_factors = self.sample_noise_schedule() self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0)) self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0)) self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std) self.timesteps = self._generate_timesteps() + self.to(device=device, dtype=dtype) + def resolve_params(self, params: SolverParams | None) -> ResolvedSolverParams: + if params is None: + return dataclasses.replace(self.default_params) + return dataclasses.replace( + self.default_params, + **{k: v for k, v in dataclasses.asdict(params).items() if v is not None}, + ) + @abstractmethod def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: """Apply a step of the diffusion process using the Solver. @@ -131,9 +185,9 @@ class Solver(fl.Module, ABC): """ max_timestep = num_train_timesteps - 1 + offset match spacing: - case TimestepSpacing.LINSPACE_FLOAT: + case TimestepSpacing.LINSPACE: return tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=float32).flip(0) - case TimestepSpacing.LINSPACE_INT: + case TimestepSpacing.LINSPACE_ROUNDED: return tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0) case TimestepSpacing.LEADING: step_ratio = num_train_timesteps // num_inference_steps @@ -142,20 +196,15 @@ class Solver(fl.Module, ABC): step_ratio = num_train_timesteps // num_inference_steps max_timestep = num_train_timesteps - 1 + offset return arange(max_timestep, offset, -step_ratio) - case TimestepSpacing.TRAILING_ALT: - # We use numpy here because: - # numpy.linspace(0,999,31)[15] is 499.49999999999994 - # torch.linspace(0,999,31)[15] is 499.5 - # and we want the same result as the original DPM codebase. - np_space = np.linspace(offset, max_timestep, num_inference_steps + 1).round().astype(int)[1:] - return tensor(np_space).flip(0) + case TimestepSpacing.CUSTOM: + raise RuntimeError("generate_timesteps called with custom spacing") def _generate_timesteps(self) -> Tensor: return self.generate_timesteps( - spacing=self.timesteps_spacing, + spacing=self.params.timesteps_spacing, num_inference_steps=self.num_inference_steps, - num_train_timesteps=self.num_train_timesteps, - offset=self.timesteps_offset, + num_train_timesteps=self.params.num_train_timesteps, + offset=self.params.timesteps_offset, ) def add_noise( @@ -239,15 +288,10 @@ class Solver(fl.Module, ABC): Returns: A new solver instance with the specified parameters. """ - num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps - first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step return self.__class__( - num_inference_steps=num_inference_steps, - num_train_timesteps=self.num_train_timesteps, - initial_diffusion_rate=self.initial_diffusion_rate, - final_diffusion_rate=self.final_diffusion_rate, - noise_schedule=self.noise_schedule, - first_inference_step=first_inference_step, + num_inference_steps=self.num_inference_steps if num_inference_steps is None else num_inference_steps, + first_inference_step=self.first_inference_step if first_inference_step is None else first_inference_step, + params=dataclasses.replace(self.params), device=self.device, dtype=self.dtype, ) @@ -282,9 +326,9 @@ class Solver(fl.Module, ABC): """ return ( linspace( - start=self.initial_diffusion_rate ** (1 / power), - end=self.final_diffusion_rate ** (1 / power), - steps=self.num_train_timesteps, + start=self.params.initial_diffusion_rate ** (1 / power), + end=self.params.final_diffusion_rate ** (1 / power), + steps=self.params.num_train_timesteps, ) ** power ) @@ -295,7 +339,7 @@ class Solver(fl.Module, ABC): Returns: A tensor representing the noise schedule. """ - match self.noise_schedule: + match self.params.noise_schedule: case "uniform": return 1 - self.sample_power_distribution(1) case "quadratic": @@ -303,7 +347,7 @@ class Solver(fl.Module, ABC): case "karras": return 1 - self.sample_power_distribution(7) case _: - raise ValueError(f"Unknown noise schedule: {self.noise_schedule}") + raise ValueError(f"Unknown noise schedule: {self.params.noise_schedule}") def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver": """Move the solver to the specified device and data type. diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index 0573526..3b1f9f5 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -27,7 +27,7 @@ from refiners.foundationals.latent_diffusion.lora import SDLoraManager from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter from refiners.foundationals.latent_diffusion.restart import Restart -from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule +from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule, SolverParams from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAdapter @@ -600,7 +600,7 @@ def sd15_ddim_karras( warn("not running on CPU, skipping") pytest.skip() - ddim_solver = DDIM(num_inference_steps=20, noise_schedule=NoiseSchedule.KARRAS) + ddim_solver = DDIM(num_inference_steps=20, params=SolverParams(noise_schedule=NoiseSchedule.KARRAS)) sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device) sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) diff --git a/tests/foundationals/latent_diffusion/test_solvers.py b/tests/foundationals/latent_diffusion/test_solvers.py index 109494a..2065806 100644 --- a/tests/foundationals/latent_diffusion/test_solvers.py +++ b/tests/foundationals/latent_diffusion/test_solvers.py @@ -11,8 +11,10 @@ from refiners.foundationals.latent_diffusion.solvers import ( DPMSolver, Euler, LCMSolver, + ModelPredictionType, NoiseSchedule, Solver, + SolverParams, TimestepSpacing, ) @@ -41,7 +43,10 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool): final_sigmas_type="sigma_min", # default before Diffusers 0.26.0 ) diffusers_scheduler.set_timesteps(n_steps) - refiners_scheduler = DPMSolver(num_inference_steps=n_steps, last_step_first_order=last_step_first_order) + refiners_scheduler = DPMSolver( + num_inference_steps=n_steps, + last_step_first_order=last_step_first_order, + ) assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps) sample = randn(1, 3, 32, 32) @@ -80,10 +85,12 @@ def test_ddim_diffusers(): assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}" -def test_euler_diffusers(): +@pytest.mark.parametrize("model_prediction_type", [ModelPredictionType.NOISE, ModelPredictionType.SAMPLE]) +def test_euler_diffusers(model_prediction_type: ModelPredictionType): from diffusers import EulerDiscreteScheduler # type: ignore manual_seed(0) + diffusers_prediction_type = "epsilon" if model_prediction_type == ModelPredictionType.NOISE else "sample" diffusers_scheduler = EulerDiscreteScheduler( beta_end=0.012, beta_schedule="scaled_linear", @@ -92,9 +99,10 @@ def test_euler_diffusers(): steps_offset=1, timestep_spacing="linspace", use_karras_sigmas=False, + prediction_type=diffusers_prediction_type, ) diffusers_scheduler.set_timesteps(30) - refiners_scheduler = Euler(num_inference_steps=30) + refiners_scheduler = Euler(num_inference_steps=30, params=SolverParams(model_prediction_type=model_prediction_type)) assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps) sample = randn(1, 4, 32, 32) @@ -192,16 +200,20 @@ def test_solver_device(test_device: Device): @pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS]) def test_solver_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device): - scheduler = DDIM(num_inference_steps=30, device=test_device, noise_schedule=noise_schedule) + scheduler = DDIM( + num_inference_steps=30, + params=SolverParams(noise_schedule=noise_schedule), + device=test_device, + ) assert len(scheduler.scale_factors) == 1000 - assert scheduler.scale_factors[0] == 1 - scheduler.initial_diffusion_rate - assert scheduler.scale_factors[-1] == 1 - scheduler.final_diffusion_rate + assert scheduler.scale_factors[0] == 1 - scheduler.params.initial_diffusion_rate + assert scheduler.scale_factors[-1] == 1 - scheduler.params.final_diffusion_rate def test_solver_timestep_spacing(): # Tests we get the results from [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 2. linspace_int = Solver.generate_timesteps( - spacing=TimestepSpacing.LINSPACE_INT, + spacing=TimestepSpacing.LINSPACE_ROUNDED, num_inference_steps=10, num_train_timesteps=1000, offset=1,