refactor solver params, add sample prediction type

This commit is contained in:
Pierre Chapuis 2024-02-22 15:16:22 +01:00
parent ddc1cf8ca7
commit bf0ba58541
9 changed files with 260 additions and 167 deletions

View file

@ -3,6 +3,23 @@ from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
from refiners.foundationals.latent_diffusion.solvers.euler import Euler from refiners.foundationals.latent_diffusion.solvers.euler import Euler
from refiners.foundationals.latent_diffusion.solvers.lcm import LCMSolver from refiners.foundationals.latent_diffusion.solvers.lcm import LCMSolver
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing from refiners.foundationals.latent_diffusion.solvers.solver import (
ModelPredictionType,
NoiseSchedule,
Solver,
SolverParams,
TimestepSpacing,
)
__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler", "LCMSolver", "NoiseSchedule", "TimestepSpacing"] __all__ = [
"Solver",
"SolverParams",
"DPMSolver",
"DDPM",
"DDIM",
"Euler",
"LCMSolver",
"ModelPredictionType",
"NoiseSchedule",
"TimestepSpacing",
]

View file

@ -1,6 +1,13 @@
import dataclasses
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing from refiners.foundationals.latent_diffusion.solvers.solver import (
ModelPredictionType,
Solver,
SolverParams,
TimestepSpacing,
)
class DDIM(Solver): class DDIM(Solver):
@ -9,42 +16,36 @@ class DDIM(Solver):
See [[arXiv:2010.02502] Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) for more details. See [[arXiv:2010.02502] Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) for more details.
""" """
default_params = dataclasses.replace(
Solver.default_params,
timesteps_spacing=TimestepSpacing.LEADING,
timesteps_offset=1,
)
def __init__( def __init__(
self, self,
num_inference_steps: int, num_inference_steps: int,
num_train_timesteps: int = 1_000,
timesteps_spacing: TimestepSpacing = TimestepSpacing.LEADING,
timesteps_offset: int = 1,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
first_inference_step: int = 0, first_inference_step: int = 0,
params: SolverParams | None = None,
device: Device | str = "cpu", device: Device | str = "cpu",
dtype: Dtype = float32, dtype: Dtype = float32,
) -> None: ) -> None:
"""Initializes a new DDIM solver. """Initializes a new DDIM solver.
Args: Args:
num_inference_steps: The number of inference steps. num_inference_steps: The number of inference steps to perform.
num_train_timesteps: The number of training timesteps. first_inference_step: The first inference step to perform.
timesteps_spacing: The spacing to use for the timesteps. params: The common parameters for solvers.
timesteps_offset: The offset to use for the timesteps.
initial_diffusion_rate: The initial diffusion rate.
final_diffusion_rate: The final diffusion rate.
noise_schedule: The noise schedule.
first_inference_step: The first inference step.
device: The PyTorch device to use. device: The PyTorch device to use.
dtype: The PyTorch data type to use. dtype: The PyTorch data type to use.
""" """
if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
raise NotImplementedError
super().__init__( super().__init__(
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
timesteps_spacing=timesteps_spacing,
timesteps_offset=timesteps_offset,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
first_inference_step=first_inference_step, first_inference_step=first_inference_step,
params=params,
device=device, device=device,
dtype=dtype, dtype=dtype,
) )

View file

@ -1,6 +1,13 @@
import dataclasses
from torch import Generator, Tensor, device as Device from torch import Generator, Tensor, device as Device
from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing from refiners.foundationals.latent_diffusion.solvers.solver import (
ModelPredictionType,
Solver,
SolverParams,
TimestepSpacing,
)
class DDPM(Solver): class DDPM(Solver):
@ -13,37 +20,34 @@ class DDPM(Solver):
See [[arXiv:2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) for more details. See [[arXiv:2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) for more details.
""" """
default_params = dataclasses.replace(
Solver.default_params,
timesteps_spacing=TimestepSpacing.LEADING,
)
def __init__( def __init__(
self, self,
num_inference_steps: int, num_inference_steps: int,
num_train_timesteps: int = 1_000,
timesteps_spacing: TimestepSpacing = TimestepSpacing.LEADING,
timesteps_offset: int = 0,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
first_inference_step: int = 0, first_inference_step: int = 0,
params: SolverParams | None = None,
device: Device | str = "cpu", device: Device | str = "cpu",
) -> None: ) -> None:
"""Initializes a new DDPM solver. """Initializes a new DDPM solver.
Args: Args:
num_inference_steps: The number of inference steps. num_inference_steps: The number of inference steps to perform.
num_train_timesteps: The number of training timesteps. first_inference_step: The first inference step to perform.
timesteps_spacing: The spacing to use for the timesteps. params: The common parameters for solvers.
timesteps_offset: The offset to use for the timesteps.
initial_diffusion_rate: The initial diffusion rate.
final_diffusion_rate: The final diffusion rate.
first_inference_step: The first inference step.
device: The PyTorch device to use. device: The PyTorch device to use.
""" """
if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
raise NotImplementedError
super().__init__( super().__init__(
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
timesteps_spacing=timesteps_spacing,
timesteps_offset=timesteps_offset,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
first_inference_step=first_inference_step, first_inference_step=first_inference_step,
params=params,
device=device, device=device,
) )

View file

@ -1,8 +1,15 @@
import dataclasses
from collections import deque from collections import deque
import numpy as np
from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing from refiners.foundationals.latent_diffusion.solvers.solver import (
ModelPredictionType,
Solver,
SolverParams,
TimestepSpacing,
)
class DPMSolver(Solver): class DPMSolver(Solver):
@ -18,44 +25,37 @@ class DPMSolver(Solver):
for the last step of the diffusion. for the last step of the diffusion.
""" """
default_params = dataclasses.replace(
Solver.default_params,
timesteps_spacing=TimestepSpacing.CUSTOM,
)
def __init__( def __init__(
self, self,
num_inference_steps: int, num_inference_steps: int,
num_train_timesteps: int = 1_000,
timesteps_spacing: TimestepSpacing = TimestepSpacing.TRAILING_ALT,
timesteps_offset: int = 0,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
last_step_first_order: bool = False,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
first_inference_step: int = 0, first_inference_step: int = 0,
params: SolverParams | None = None,
last_step_first_order: bool = False,
device: Device | str = "cpu", device: Device | str = "cpu",
dtype: Dtype = float32, dtype: Dtype = float32,
): ):
"""Initializes a new DPM solver. """Initializes a new DPM solver.
Args: Args:
num_inference_steps: The number of inference steps. num_inference_steps: The number of inference steps to perform.
num_train_timesteps: The number of training timesteps. first_inference_step: The first inference step to perform.
timesteps_spacing: The spacing to use for the timesteps. params: The common parameters for solvers.
timesteps_offset: The offset to use for the timesteps.
initial_diffusion_rate: The initial diffusion rate.
final_diffusion_rate: The final diffusion rate.
last_step_first_order: Use a first-order update for the last step. last_step_first_order: Use a first-order update for the last step.
noise_schedule: The noise schedule.
first_inference_step: The first inference step.
device: The PyTorch device to use. device: The PyTorch device to use.
dtype: The PyTorch data type to use. dtype: The PyTorch data type to use.
""" """
if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
raise NotImplementedError
super().__init__( super().__init__(
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
timesteps_spacing=timesteps_spacing,
timesteps_offset=timesteps_offset,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
first_inference_step=first_inference_step, first_inference_step=first_inference_step,
params=params,
device=device, device=device,
dtype=dtype, dtype=dtype,
) )
@ -80,6 +80,19 @@ class DPMSolver(Solver):
r.last_step_first_order = self.last_step_first_order r.last_step_first_order = self.last_step_first_order
return r return r
def _generate_timesteps(self) -> Tensor:
    """Build the timesteps tensor for this solver.

    When the resolved spacing is `TimestepSpacing.CUSTOM`, the schedule is
    computed here; any other spacing is delegated to the parent class.

    Returns:
        The timesteps, flipped so the end of the linear grid comes first.
    """
    resolved = self.params
    if resolved.timesteps_spacing == TimestepSpacing.CUSTOM:
        # NumPy is used on purpose instead of torch:
        #   numpy.linspace(0,999,31)[15] is 499.49999999999994
        #   torch.linspace(0,999,31)[15] is 499.5
        # and the rounded values must match the original DPM codebase.
        lowest = resolved.timesteps_offset
        highest = resolved.num_train_timesteps - 1 + lowest
        grid = np.linspace(lowest, highest, self.num_inference_steps + 1)
        # Round to integers, then drop the first grid point.
        kept = grid.round().astype(int)[1:]
        return tensor(kept).flip(0)
    return super()._generate_timesteps()
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor: def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
"""Applies a first-order backward Euler update to the input data `x`. """Applies a first-order backward Euler update to the input data `x`.

View file

@ -2,7 +2,12 @@ import numpy as np
import torch import torch
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing from refiners.foundationals.latent_diffusion.solvers.solver import (
ModelPredictionType,
NoiseSchedule,
Solver,
SolverParams,
)
class Euler(Solver): class Euler(Solver):
@ -15,41 +20,27 @@ class Euler(Solver):
def __init__( def __init__(
self, self,
num_inference_steps: int, num_inference_steps: int,
num_train_timesteps: int = 1_000,
timesteps_spacing: TimestepSpacing = TimestepSpacing.LINSPACE_FLOAT,
timesteps_offset: int = 0,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
first_inference_step: int = 0, first_inference_step: int = 0,
params: SolverParams | None = None,
device: Device | str = "cpu", device: Device | str = "cpu",
dtype: Dtype = float32, dtype: Dtype = float32,
): ):
"""Initializes a new Euler solver. """Initializes a new Euler solver.
Args: Args:
num_inference_steps: The number of inference steps. num_inference_steps: The number of inference steps to perform.
num_train_timesteps: The number of training timesteps. first_inference_step: The first inference step to perform.
timesteps_spacing: The spacing to use for the timesteps. params: The common parameters for solvers.
timesteps_offset: The offset to use for the timesteps.
initial_diffusion_rate: The initial diffusion rate.
final_diffusion_rate: The final diffusion rate.
noise_schedule: The noise schedule.
first_inference_step: The first inference step.
device: The PyTorch device to use. device: The PyTorch device to use.
dtype: The PyTorch data type to use. dtype: The PyTorch data type to use.
""" """
if noise_schedule != NoiseSchedule.QUADRATIC: if params and params.noise_schedule not in (NoiseSchedule.QUADRATIC, None):
raise NotImplementedError raise NotImplementedError
super().__init__( super().__init__(
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
timesteps_spacing=timesteps_spacing,
timesteps_offset=timesteps_offset,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
first_inference_step=first_inference_step, first_inference_step=first_inference_step,
params=params,
device=device, device=device,
dtype=dtype, dtype=dtype,
) )
@ -85,7 +76,7 @@ class Euler(Solver):
Args: Args:
x: The input tensor to apply the diffusion process to. x: The input tensor to apply the diffusion process to.
predicted_noise: The predicted noise tensor for the current step. predicted_noise: The predicted noise tensor for the current step (or x0 if the prediction type is SAMPLE).
step: The current step of the diffusion process. step: The current step of the diffusion process.
generator: The random number generator to use for sampling noise (ignored, this solver is deterministic). generator: The random number generator to use for sampling noise (ignored, this solver is deterministic).
@ -93,4 +84,11 @@ class Euler(Solver):
The denoised version of the input data `x`. The denoised version of the input data `x`.
""" """
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}" assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
if self.params.model_prediction_type == ModelPredictionType.SAMPLE:
x0 = predicted_noise # the model does not actually predict the noise but x0
ratio = self.sigmas[step + 1] / self.sigmas[step]
return ratio * x + (1 - ratio) * x0
assert self.params.model_prediction_type == ModelPredictionType.NOISE
return x + predicted_noise * (self.sigmas[step + 1] - self.sigmas[step]) return x + predicted_noise * (self.sigmas[step + 1] - self.sigmas[step])

View file

@ -1,7 +1,14 @@
import dataclasses
import torch import torch
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing from refiners.foundationals.latent_diffusion.solvers.solver import (
ModelPredictionType,
Solver,
SolverParams,
TimestepSpacing,
)
class LCMSolver(Solver): class LCMSolver(Solver):
@ -15,55 +22,52 @@ class LCMSolver(Solver):
for details. for details.
""" """
# The spacing parameter is actually the spacing of the underlying DPM solver.
default_params = dataclasses.replace(Solver.default_params, timesteps_spacing=TimestepSpacing.TRAILING)
def __init__( def __init__(
self, self,
num_inference_steps: int, num_inference_steps: int,
num_train_timesteps: int = 1_000, first_inference_step: int = 0,
timesteps_spacing: TimestepSpacing = TimestepSpacing.TRAILING, params: SolverParams | None = None,
timesteps_offset: int = 0,
num_orig_steps: int = 50, num_orig_steps: int = 50,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
device: torch.device | str = "cpu", device: torch.device | str = "cpu",
dtype: torch.dtype = torch.float32, dtype: torch.dtype = torch.float32,
): ):
"""Initializes a new LCM solver. """Initializes a new LCM solver.
Args: Args:
num_inference_steps: The number of inference steps. num_inference_steps: The number of inference steps to perform.
num_train_timesteps: The number of training timesteps. first_inference_step: The first inference step to perform.
timesteps_spacing: The spacing to use for the timesteps. params: The common parameters for solvers.
timesteps_offset: The offset to use for the timesteps.
num_orig_steps: The number of inference steps of the emulated DPM solver. num_orig_steps: The number of inference steps of the emulated DPM solver.
initial_diffusion_rate: The initial diffusion rate.
final_diffusion_rate: The final diffusion rate.
noise_schedule: The noise schedule.
device: The PyTorch device to use. device: The PyTorch device to use.
dtype: The PyTorch data type to use. dtype: The PyTorch data type to use.
""" """
assert ( assert (
num_orig_steps >= num_inference_steps num_orig_steps >= num_inference_steps
), f"num_orig_steps ({num_orig_steps}) < num_inference_steps ({num_inference_steps})" ), f"num_orig_steps ({num_orig_steps}) < num_inference_steps ({num_inference_steps})"
params = self.resolve_params(params)
if params.model_prediction_type != ModelPredictionType.NOISE:
raise NotImplementedError
self._dpm = [ self._dpm = [
DPMSolver( DPMSolver(
num_inference_steps=num_orig_steps, num_inference_steps=num_orig_steps,
num_train_timesteps=num_train_timesteps, params=SolverParams(
timesteps_spacing=timesteps_spacing, num_train_timesteps=params.num_train_timesteps,
timesteps_spacing=params.timesteps_spacing,
),
device=device, device=device,
dtype=dtype, dtype=dtype,
) )
] ]
super().__init__( super().__init__(
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps, first_inference_step=first_inference_step,
timesteps_spacing=timesteps_spacing, params=params,
timesteps_offset=timesteps_offset,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
device=device, device=device,
dtype=dtype, dtype=dtype,
) )

View file

@ -1,3 +1,4 @@
import dataclasses
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from typing import TypeVar from typing import TypeVar
@ -30,18 +31,64 @@ class TimestepSpacing(str, Enum):
See [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 2. See [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 2.
Attributes: Attributes:
LINSPACE_FLOAT: Sample N steps with linear interpolation, return a floating-point tensor. LINSPACE: Sample N steps with linear interpolation, return a floating-point tensor.
LINSPACE_INT: Same as LINSPACE_FLOAT but return an integer tensor with rounded timesteps. LINSPACE_ROUNDED: Same as LINSPACE but return an integer tensor with rounded timesteps.
LEADING: Sample N+1 steps, do not include the last timestep (i.e. bad - non-zero SNR). Used in DDIM. LEADING: Sample N+1 steps, do not include the last timestep (i.e. bad - non-zero SNR). Used in DDIM.
TRAILING: Sample N+1 steps, do not include the first timestep. TRAILING: Sample N+1 steps, do not include the first timestep.
TRAILING_ALT: Variant of TRAILING used in DPM. CUSTOM: Use custom timespacing in solver (override `_generate_timesteps`, see DPM).
""" """
LINSPACE_FLOAT = "linspace_float" LINSPACE = "linspace"
LINSPACE_INT = "linspace_int" LINSPACE_ROUNDED = "linspace_rounded"
LEADING = "leading" LEADING = "leading"
TRAILING = "trailing" TRAILING = "trailing"
TRAILING_ALT = "trailing_alt" CUSTOM = "custom"
class ModelPredictionType(str, Enum):
    """What quantity the diffusion model outputs.

    Attributes:
        NOISE: The model predicts the noise (epsilon).
        SAMPLE: The model predicts the denoised sample (x0).
    """

    # Epsilon prediction.
    NOISE = "noise"
    # x0 prediction.
    SAMPLE = "sample"
@dataclasses.dataclass(frozen=True, kw_only=True)
class SolverParams:
    """Parameters shared by all solvers.

    Every field defaults to `None`. Instances are frozen and all fields are
    keyword-only.

    Args:
        num_train_timesteps: The number of timesteps used to train the diffusion process.
        timesteps_spacing: The spacing to use for the timesteps.
        timesteps_offset: The offset to use for the timesteps.
        initial_diffusion_rate: The initial diffusion rate used to sample the noise schedule.
        final_diffusion_rate: The final diffusion rate used to sample the noise schedule.
        noise_schedule: The kind of noise schedule to sample.
        model_prediction_type: Defines what the model predicts.
    """

    # Timesteps.
    num_train_timesteps: int | None = None
    timesteps_spacing: TimestepSpacing | None = None
    timesteps_offset: int | None = None

    # Noise schedule.
    initial_diffusion_rate: float | None = None
    final_diffusion_rate: float | None = None
    noise_schedule: NoiseSchedule | None = None

    # Model output.
    model_prediction_type: ModelPredictionType | None = None
@dataclasses.dataclass(frozen=True, kw_only=True)
class ResolvedSolverParams(SolverParams):
    """Variant of `SolverParams` whose fields are annotated as non-optional.

    Each field of the parent class is re-declared with its `| None` removed
    from the annotation; no new field is introduced.
    """

    # Timesteps.
    num_train_timesteps: int
    timesteps_spacing: TimestepSpacing
    timesteps_offset: int

    # Noise schedule.
    initial_diffusion_rate: float
    final_diffusion_rate: float
    noise_schedule: NoiseSchedule

    # Model output.
    model_prediction_type: ModelPredictionType
class Solver(fl.Module, ABC): class Solver(fl.Module, ABC):
@ -55,17 +102,23 @@ class Solver(fl.Module, ABC):
""" """
timesteps: Tensor timesteps: Tensor
params: ResolvedSolverParams
default_params = ResolvedSolverParams(
num_train_timesteps=1000,
timesteps_spacing=TimestepSpacing.LINSPACE,
timesteps_offset=0,
initial_diffusion_rate=8.5e-4,
final_diffusion_rate=1.2e-2,
noise_schedule=NoiseSchedule.QUADRATIC,
model_prediction_type=ModelPredictionType.NOISE,
)
def __init__( def __init__(
self, self,
num_inference_steps: int, num_inference_steps: int,
num_train_timesteps: int = 1_000,
timesteps_spacing: TimestepSpacing = TimestepSpacing.LINSPACE_FLOAT,
timesteps_offset: int = 0,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
first_inference_step: int = 0, first_inference_step: int = 0,
params: SolverParams | None = None,
device: Device | str = "cpu", device: Device | str = "cpu",
dtype: DType = float32, dtype: DType = float32,
) -> None: ) -> None:
@ -73,32 +126,33 @@ class Solver(fl.Module, ABC):
Args: Args:
num_inference_steps: The number of inference steps to perform. num_inference_steps: The number of inference steps to perform.
num_train_timesteps: The number of timesteps used to train the diffusion process.
timesteps_spacing: The spacing to use for the timesteps.
timesteps_offset: The offset to use for the timesteps.
initial_diffusion_rate: The initial diffusion rate used to sample the noise schedule.
final_diffusion_rate: The final diffusion rate used to sample the noise schedule.
noise_schedule: The noise schedule used to sample the noise schedule.
first_inference_step: The first inference step to perform. first_inference_step: The first inference step to perform.
params: The common parameters for solvers.
device: The PyTorch device to use for the solver's tensors. device: The PyTorch device to use for the solver's tensors.
dtype: The PyTorch data type to use for the solver's tensors. dtype: The PyTorch data type to use for the solver's tensors.
""" """
super().__init__() super().__init__()
self.num_inference_steps = num_inference_steps self.num_inference_steps = num_inference_steps
self.num_train_timesteps = num_train_timesteps
self.timesteps_spacing = timesteps_spacing
self.timesteps_offset = timesteps_offset
self.initial_diffusion_rate = initial_diffusion_rate
self.final_diffusion_rate = final_diffusion_rate
self.noise_schedule = noise_schedule
self.first_inference_step = first_inference_step self.first_inference_step = first_inference_step
self.params = self.resolve_params(params)
self.scale_factors = self.sample_noise_schedule() self.scale_factors = self.sample_noise_schedule()
self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0)) self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0)) self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std) self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
self.timesteps = self._generate_timesteps() self.timesteps = self._generate_timesteps()
self.to(device=device, dtype=dtype) self.to(device=device, dtype=dtype)
def resolve_params(self, params: SolverParams | None) -> ResolvedSolverParams:
    """Merge user-supplied parameters with this solver's defaults.

    Args:
        params: The parameters to apply, or `None` to use only the defaults.
            Fields whose value is `None` are ignored.

    Returns:
        A new object built from `self.default_params`, with every
        non-`None` field of `params` applied on top.
    """
    overrides = {}
    if params is not None:
        # Keep only the fields that were explicitly set.
        for field_name, field_value in dataclasses.asdict(params).items():
            if field_value is not None:
                overrides[field_name] = field_value
    # With no overrides this still returns a new object, not the class-level default itself.
    return dataclasses.replace(self.default_params, **overrides)
@abstractmethod @abstractmethod
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
"""Apply a step of the diffusion process using the Solver. """Apply a step of the diffusion process using the Solver.
@ -131,9 +185,9 @@ class Solver(fl.Module, ABC):
""" """
max_timestep = num_train_timesteps - 1 + offset max_timestep = num_train_timesteps - 1 + offset
match spacing: match spacing:
case TimestepSpacing.LINSPACE_FLOAT: case TimestepSpacing.LINSPACE:
return tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=float32).flip(0) return tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=float32).flip(0)
case TimestepSpacing.LINSPACE_INT: case TimestepSpacing.LINSPACE_ROUNDED:
return tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0) return tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0)
case TimestepSpacing.LEADING: case TimestepSpacing.LEADING:
step_ratio = num_train_timesteps // num_inference_steps step_ratio = num_train_timesteps // num_inference_steps
@ -142,20 +196,15 @@ class Solver(fl.Module, ABC):
step_ratio = num_train_timesteps // num_inference_steps step_ratio = num_train_timesteps // num_inference_steps
max_timestep = num_train_timesteps - 1 + offset max_timestep = num_train_timesteps - 1 + offset
return arange(max_timestep, offset, -step_ratio) return arange(max_timestep, offset, -step_ratio)
case TimestepSpacing.TRAILING_ALT: case TimestepSpacing.CUSTOM:
# We use numpy here because: raise RuntimeError("generate_timesteps called with custom spacing")
# numpy.linspace(0,999,31)[15] is 499.49999999999994
# torch.linspace(0,999,31)[15] is 499.5
# and we want the same result as the original DPM codebase.
np_space = np.linspace(offset, max_timestep, num_inference_steps + 1).round().astype(int)[1:]
return tensor(np_space).flip(0)
def _generate_timesteps(self) -> Tensor: def _generate_timesteps(self) -> Tensor:
return self.generate_timesteps( return self.generate_timesteps(
spacing=self.timesteps_spacing, spacing=self.params.timesteps_spacing,
num_inference_steps=self.num_inference_steps, num_inference_steps=self.num_inference_steps,
num_train_timesteps=self.num_train_timesteps, num_train_timesteps=self.params.num_train_timesteps,
offset=self.timesteps_offset, offset=self.params.timesteps_offset,
) )
def add_noise( def add_noise(
@ -239,15 +288,10 @@ class Solver(fl.Module, ABC):
Returns: Returns:
A new solver instance with the specified parameters. A new solver instance with the specified parameters.
""" """
num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step
return self.__class__( return self.__class__(
num_inference_steps=num_inference_steps, num_inference_steps=self.num_inference_steps if num_inference_steps is None else num_inference_steps,
num_train_timesteps=self.num_train_timesteps, first_inference_step=self.first_inference_step if first_inference_step is None else first_inference_step,
initial_diffusion_rate=self.initial_diffusion_rate, params=dataclasses.replace(self.params),
final_diffusion_rate=self.final_diffusion_rate,
noise_schedule=self.noise_schedule,
first_inference_step=first_inference_step,
device=self.device, device=self.device,
dtype=self.dtype, dtype=self.dtype,
) )
@ -282,9 +326,9 @@ class Solver(fl.Module, ABC):
""" """
return ( return (
linspace( linspace(
start=self.initial_diffusion_rate ** (1 / power), start=self.params.initial_diffusion_rate ** (1 / power),
end=self.final_diffusion_rate ** (1 / power), end=self.params.final_diffusion_rate ** (1 / power),
steps=self.num_train_timesteps, steps=self.params.num_train_timesteps,
) )
** power ** power
) )
@ -295,7 +339,7 @@ class Solver(fl.Module, ABC):
Returns: Returns:
A tensor representing the noise schedule. A tensor representing the noise schedule.
""" """
match self.noise_schedule: match self.params.noise_schedule:
case "uniform": case "uniform":
return 1 - self.sample_power_distribution(1) return 1 - self.sample_power_distribution(1)
case "quadratic": case "quadratic":
@ -303,7 +347,7 @@ class Solver(fl.Module, ABC):
case "karras": case "karras":
return 1 - self.sample_power_distribution(7) return 1 - self.sample_power_distribution(7)
case _: case _:
raise ValueError(f"Unknown noise schedule: {self.noise_schedule}") raise ValueError(f"Unknown noise schedule: {self.params.noise_schedule}")
def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver": def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver":
"""Move the solver to the specified device and data type. """Move the solver to the specified device and data type.

View file

@ -27,7 +27,7 @@ from refiners.foundationals.latent_diffusion.lora import SDLoraManager
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
from refiners.foundationals.latent_diffusion.restart import Restart from refiners.foundationals.latent_diffusion.restart import Restart
from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule, SolverParams
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAdapter from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAdapter
@ -600,7 +600,7 @@ def sd15_ddim_karras(
warn("not running on CPU, skipping") warn("not running on CPU, skipping")
pytest.skip() pytest.skip()
ddim_solver = DDIM(num_inference_steps=20, noise_schedule=NoiseSchedule.KARRAS) ddim_solver = DDIM(num_inference_steps=20, params=SolverParams(noise_schedule=NoiseSchedule.KARRAS))
sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device) sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)

View file

@ -11,8 +11,10 @@ from refiners.foundationals.latent_diffusion.solvers import (
DPMSolver, DPMSolver,
Euler, Euler,
LCMSolver, LCMSolver,
ModelPredictionType,
NoiseSchedule, NoiseSchedule,
Solver, Solver,
SolverParams,
TimestepSpacing, TimestepSpacing,
) )
@ -41,7 +43,10 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
final_sigmas_type="sigma_min", # default before Diffusers 0.26.0 final_sigmas_type="sigma_min", # default before Diffusers 0.26.0
) )
diffusers_scheduler.set_timesteps(n_steps) diffusers_scheduler.set_timesteps(n_steps)
refiners_scheduler = DPMSolver(num_inference_steps=n_steps, last_step_first_order=last_step_first_order) refiners_scheduler = DPMSolver(
num_inference_steps=n_steps,
last_step_first_order=last_step_first_order,
)
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps) assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
sample = randn(1, 3, 32, 32) sample = randn(1, 3, 32, 32)
@ -80,10 +85,12 @@ def test_ddim_diffusers():
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}" assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
def test_euler_diffusers(): @pytest.mark.parametrize("model_prediction_type", [ModelPredictionType.NOISE, ModelPredictionType.SAMPLE])
def test_euler_diffusers(model_prediction_type: ModelPredictionType):
from diffusers import EulerDiscreteScheduler # type: ignore from diffusers import EulerDiscreteScheduler # type: ignore
manual_seed(0) manual_seed(0)
diffusers_prediction_type = "epsilon" if model_prediction_type == ModelPredictionType.NOISE else "sample"
diffusers_scheduler = EulerDiscreteScheduler( diffusers_scheduler = EulerDiscreteScheduler(
beta_end=0.012, beta_end=0.012,
beta_schedule="scaled_linear", beta_schedule="scaled_linear",
@ -92,9 +99,10 @@ def test_euler_diffusers():
steps_offset=1, steps_offset=1,
timestep_spacing="linspace", timestep_spacing="linspace",
use_karras_sigmas=False, use_karras_sigmas=False,
prediction_type=diffusers_prediction_type,
) )
diffusers_scheduler.set_timesteps(30) diffusers_scheduler.set_timesteps(30)
refiners_scheduler = Euler(num_inference_steps=30) refiners_scheduler = Euler(num_inference_steps=30, params=SolverParams(model_prediction_type=model_prediction_type))
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps) assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
sample = randn(1, 4, 32, 32) sample = randn(1, 4, 32, 32)
@ -192,16 +200,20 @@ def test_solver_device(test_device: Device):
@pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS]) @pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS])
def test_solver_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device): def test_solver_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device):
scheduler = DDIM(num_inference_steps=30, device=test_device, noise_schedule=noise_schedule) scheduler = DDIM(
num_inference_steps=30,
params=SolverParams(noise_schedule=noise_schedule),
device=test_device,
)
assert len(scheduler.scale_factors) == 1000 assert len(scheduler.scale_factors) == 1000
assert scheduler.scale_factors[0] == 1 - scheduler.initial_diffusion_rate assert scheduler.scale_factors[0] == 1 - scheduler.params.initial_diffusion_rate
assert scheduler.scale_factors[-1] == 1 - scheduler.final_diffusion_rate assert scheduler.scale_factors[-1] == 1 - scheduler.params.final_diffusion_rate
def test_solver_timestep_spacing(): def test_solver_timestep_spacing():
# Tests we get the results from [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 2. # Tests we get the results from [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 2.
linspace_int = Solver.generate_timesteps( linspace_int = Solver.generate_timesteps(
spacing=TimestepSpacing.LINSPACE_INT, spacing=TimestepSpacing.LINSPACE_ROUNDED,
num_inference_steps=10, num_inference_steps=10,
num_train_timesteps=1000, num_train_timesteps=1000,
offset=1, offset=1,