refactor solver params, add sample prediction type

This commit is contained in:
Pierre Chapuis 2024-02-22 15:16:22 +01:00
parent ddc1cf8ca7
commit bf0ba58541
9 changed files with 260 additions and 167 deletions

View file

@ -3,6 +3,23 @@ from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
from refiners.foundationals.latent_diffusion.solvers.euler import Euler
from refiners.foundationals.latent_diffusion.solvers.lcm import LCMSolver
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
from refiners.foundationals.latent_diffusion.solvers.solver import (
ModelPredictionType,
NoiseSchedule,
Solver,
SolverParams,
TimestepSpacing,
)
__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler", "LCMSolver", "NoiseSchedule", "TimestepSpacing"]
__all__ = [
"Solver",
"SolverParams",
"DPMSolver",
"DDPM",
"DDIM",
"Euler",
"LCMSolver",
"ModelPredictionType",
"NoiseSchedule",
"TimestepSpacing",
]

View file

@ -1,6 +1,13 @@
import dataclasses
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
from refiners.foundationals.latent_diffusion.solvers.solver import (
ModelPredictionType,
Solver,
SolverParams,
TimestepSpacing,
)
class DDIM(Solver):
@ -9,42 +16,36 @@ class DDIM(Solver):
See [[arXiv:2010.02502] Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) for more details.
"""
default_params = dataclasses.replace(
Solver.default_params,
timesteps_spacing=TimestepSpacing.LEADING,
timesteps_offset=1,
)
def __init__(
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
timesteps_spacing: TimestepSpacing = TimestepSpacing.LEADING,
timesteps_offset: int = 1,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
first_inference_step: int = 0,
params: SolverParams | None = None,
device: Device | str = "cpu",
dtype: Dtype = float32,
) -> None:
"""Initializes a new DDIM solver.
Args:
num_inference_steps: The number of inference steps.
num_train_timesteps: The number of training timesteps.
timesteps_spacing: The spacing to use for the timesteps.
timesteps_offset: The offset to use for the timesteps.
initial_diffusion_rate: The initial diffusion rate.
final_diffusion_rate: The final diffusion rate.
noise_schedule: The noise schedule.
first_inference_step: The first inference step.
num_inference_steps: The number of inference steps to perform.
first_inference_step: The first inference step to perform.
params: The common parameters for solvers.
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
raise NotImplementedError
super().__init__(
num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
timesteps_spacing=timesteps_spacing,
timesteps_offset=timesteps_offset,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
first_inference_step=first_inference_step,
params=params,
device=device,
dtype=dtype,
)

View file

@ -1,6 +1,13 @@
import dataclasses
from torch import Generator, Tensor, device as Device
from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing
from refiners.foundationals.latent_diffusion.solvers.solver import (
ModelPredictionType,
Solver,
SolverParams,
TimestepSpacing,
)
class DDPM(Solver):
@ -13,37 +20,34 @@ class DDPM(Solver):
See [[arXiv:2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) for more details.
"""
default_params = dataclasses.replace(
Solver.default_params,
timesteps_spacing=TimestepSpacing.LEADING,
)
def __init__(
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
timesteps_spacing: TimestepSpacing = TimestepSpacing.LEADING,
timesteps_offset: int = 0,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
first_inference_step: int = 0,
params: SolverParams | None = None,
device: Device | str = "cpu",
) -> None:
"""Initializes a new DDPM solver.
Args:
num_inference_steps: The number of inference steps.
num_train_timesteps: The number of training timesteps.
timesteps_spacing: The spacing to use for the timesteps.
timesteps_offset: The offset to use for the timesteps.
initial_diffusion_rate: The initial diffusion rate.
final_diffusion_rate: The final diffusion rate.
first_inference_step: The first inference step.
num_inference_steps: The number of inference steps to perform.
first_inference_step: The first inference step to perform.
params: The common parameters for solvers.
device: The PyTorch device to use.
"""
if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
raise NotImplementedError
super().__init__(
num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
timesteps_spacing=timesteps_spacing,
timesteps_offset=timesteps_offset,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
first_inference_step=first_inference_step,
params=params,
device=device,
)

View file

@ -1,8 +1,15 @@
import dataclasses
from collections import deque
import numpy as np
from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
from refiners.foundationals.latent_diffusion.solvers.solver import (
ModelPredictionType,
Solver,
SolverParams,
TimestepSpacing,
)
class DPMSolver(Solver):
@ -18,44 +25,37 @@ class DPMSolver(Solver):
for the last step of the diffusion.
"""
default_params = dataclasses.replace(
Solver.default_params,
timesteps_spacing=TimestepSpacing.CUSTOM,
)
def __init__(
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
timesteps_spacing: TimestepSpacing = TimestepSpacing.TRAILING_ALT,
timesteps_offset: int = 0,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
last_step_first_order: bool = False,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
first_inference_step: int = 0,
params: SolverParams | None = None,
last_step_first_order: bool = False,
device: Device | str = "cpu",
dtype: Dtype = float32,
):
"""Initializes a new DPM solver.
Args:
num_inference_steps: The number of inference steps.
num_train_timesteps: The number of training timesteps.
timesteps_spacing: The spacing to use for the timesteps.
timesteps_offset: The offset to use for the timesteps.
initial_diffusion_rate: The initial diffusion rate.
final_diffusion_rate: The final diffusion rate.
num_inference_steps: The number of inference steps to perform.
first_inference_step: The first inference step to perform.
params: The common parameters for solvers.
last_step_first_order: Use a first-order update for the last step.
noise_schedule: The noise schedule.
first_inference_step: The first inference step.
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
raise NotImplementedError
super().__init__(
num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
timesteps_spacing=timesteps_spacing,
timesteps_offset=timesteps_offset,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
first_inference_step=first_inference_step,
params=params,
device=device,
dtype=dtype,
)
@ -80,6 +80,19 @@ class DPMSolver(Solver):
r.last_step_first_order = self.last_step_first_order
return r
def _generate_timesteps(self) -> Tensor:
    """Generate the inference timesteps for this solver.

    When the resolved parameters request `TimestepSpacing.CUSTOM`, the
    timesteps are computed here; any other spacing is delegated to the
    parent implementation.

    Returns:
        A tensor of integer timesteps in decreasing order (CUSTOM spacing),
        or whatever the parent class produces for the other spacings.
    """
    params = self.params
    if params.timesteps_spacing == TimestepSpacing.CUSTOM:
        # numpy is used on purpose instead of torch:
        #   numpy.linspace(0,999,31)[15] is 499.49999999999994
        #   torch.linspace(0,999,31)[15] is 499.5
        # and rounding must match the original DPM codebase.
        first_timestep = params.timesteps_offset
        last_timestep = params.num_train_timesteps - 1 + first_timestep
        grid = np.linspace(first_timestep, last_timestep, self.num_inference_steps + 1)
        # Drop the first sample, then reverse so that timesteps decrease.
        rounded = grid.round().astype(int)[1:]
        return tensor(rounded).flip(0)
    return super()._generate_timesteps()
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
"""Applies a first-order backward Euler update to the input data `x`.

View file

@ -2,7 +2,12 @@ import numpy as np
import torch
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
from refiners.foundationals.latent_diffusion.solvers.solver import (
ModelPredictionType,
NoiseSchedule,
Solver,
SolverParams,
)
class Euler(Solver):
@ -15,41 +20,27 @@ class Euler(Solver):
def __init__(
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
timesteps_spacing: TimestepSpacing = TimestepSpacing.LINSPACE_FLOAT,
timesteps_offset: int = 0,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
first_inference_step: int = 0,
params: SolverParams | None = None,
device: Device | str = "cpu",
dtype: Dtype = float32,
):
"""Initializes a new Euler solver.
Args:
num_inference_steps: The number of inference steps.
num_train_timesteps: The number of training timesteps.
timesteps_spacing: The spacing to use for the timesteps.
timesteps_offset: The offset to use for the timesteps.
initial_diffusion_rate: The initial diffusion rate.
final_diffusion_rate: The final diffusion rate.
noise_schedule: The noise schedule.
first_inference_step: The first inference step.
num_inference_steps: The number of inference steps to perform.
first_inference_step: The first inference step to perform.
params: The common parameters for solvers.
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
if noise_schedule != NoiseSchedule.QUADRATIC:
if params and params.noise_schedule not in (NoiseSchedule.QUADRATIC, None):
raise NotImplementedError
super().__init__(
num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
timesteps_spacing=timesteps_spacing,
timesteps_offset=timesteps_offset,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
first_inference_step=first_inference_step,
params=params,
device=device,
dtype=dtype,
)
@ -85,7 +76,7 @@ class Euler(Solver):
Args:
x: The input tensor to apply the diffusion process to.
predicted_noise: The predicted noise tensor for the current step.
predicted_noise: The predicted noise tensor for the current step (or x0 if the prediction type is SAMPLE).
step: The current step of the diffusion process.
generator: The random number generator to use for sampling noise (ignored, this solver is deterministic).
@ -93,4 +84,11 @@ class Euler(Solver):
The denoised version of the input data `x`.
"""
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
if self.params.model_prediction_type == ModelPredictionType.SAMPLE:
x0 = predicted_noise # the model does not actually predict the noise but x0
ratio = self.sigmas[step + 1] / self.sigmas[step]
return ratio * x + (1 - ratio) * x0
assert self.params.model_prediction_type == ModelPredictionType.NOISE
return x + predicted_noise * (self.sigmas[step + 1] - self.sigmas[step])

View file

@ -1,7 +1,14 @@
import dataclasses
import torch
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
from refiners.foundationals.latent_diffusion.solvers.solver import (
ModelPredictionType,
Solver,
SolverParams,
TimestepSpacing,
)
class LCMSolver(Solver):
@ -15,55 +22,52 @@ class LCMSolver(Solver):
for details.
"""
# The spacing parameter is actually the spacing of the underlying DPM solver.
default_params = dataclasses.replace(Solver.default_params, timesteps_spacing=TimestepSpacing.TRAILING)
def __init__(
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
timesteps_spacing: TimestepSpacing = TimestepSpacing.TRAILING,
timesteps_offset: int = 0,
first_inference_step: int = 0,
params: SolverParams | None = None,
num_orig_steps: int = 50,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
device: torch.device | str = "cpu",
dtype: torch.dtype = torch.float32,
):
"""Initializes a new LCM solver.
Args:
num_inference_steps: The number of inference steps.
num_train_timesteps: The number of training timesteps.
timesteps_spacing: The spacing to use for the timesteps.
timesteps_offset: The offset to use for the timesteps.
num_inference_steps: The number of inference steps to perform.
first_inference_step: The first inference step to perform.
params: The common parameters for solvers.
num_orig_steps: The number of inference steps of the emulated DPM solver.
initial_diffusion_rate: The initial diffusion rate.
final_diffusion_rate: The final diffusion rate.
noise_schedule: The noise schedule.
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
assert (
num_orig_steps >= num_inference_steps
), f"num_orig_steps ({num_orig_steps}) < num_inference_steps ({num_inference_steps})"
params = self.resolve_params(params)
if params.model_prediction_type != ModelPredictionType.NOISE:
raise NotImplementedError
self._dpm = [
DPMSolver(
num_inference_steps=num_orig_steps,
num_train_timesteps=num_train_timesteps,
timesteps_spacing=timesteps_spacing,
params=SolverParams(
num_train_timesteps=params.num_train_timesteps,
timesteps_spacing=params.timesteps_spacing,
),
device=device,
dtype=dtype,
)
]
super().__init__(
num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
timesteps_spacing=timesteps_spacing,
timesteps_offset=timesteps_offset,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
first_inference_step=first_inference_step,
params=params,
device=device,
dtype=dtype,
)

View file

@ -1,3 +1,4 @@
import dataclasses
from abc import ABC, abstractmethod
from enum import Enum
from typing import TypeVar
@ -30,18 +31,64 @@ class TimestepSpacing(str, Enum):
See [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 2.
Attributes:
LINSPACE_FLOAT: Sample N steps with linear interpolation, return a floating-point tensor.
LINSPACE_INT: Same as LINSPACE_FLOAT but return an integer tensor with rounded timesteps.
LINSPACE: Sample N steps with linear interpolation, return a floating-point tensor.
LINSPACE_ROUNDED: Same as LINSPACE but return an integer tensor with rounded timesteps.
LEADING: Sample N+1 steps, do not include the last timestep (i.e. bad - non-zero SNR). Used in DDIM.
TRAILING: Sample N+1 steps, do not include the first timestep.
TRAILING_ALT: Variant of TRAILING used in DPM.
CUSTOM: Use custom timespacing in solver (override `_generate_timesteps`, see DPM).
"""
LINSPACE_FLOAT = "linspace_float"
LINSPACE_INT = "linspace_int"
LINSPACE = "linspace"
LINSPACE_ROUNDED = "linspace_rounded"
LEADING = "leading"
TRAILING = "trailing"
TRAILING_ALT = "trailing_alt"
CUSTOM = "custom"
class ModelPredictionType(str, Enum):
    """The quantity the diffusion model outputs at each step.

    Attributes:
        NOISE: The model outputs the noise (epsilon).
        SAMPLE: The model outputs the denoised sample (x0).
    """

    # Members are strings, so they compare equal to their raw values.
    NOISE = "noise"
    SAMPLE = "sample"
@dataclasses.dataclass(kw_only=True, frozen=True)
class SolverParams:
    """Parameters shared by all solvers.

    Every field is optional and defaults to `None`. Instances are immutable
    and must be built with keyword arguments.

    Args:
        num_train_timesteps: The number of timesteps used to train the diffusion process.
        timesteps_spacing: The spacing to use for the timesteps.
        timesteps_offset: The offset to use for the timesteps.
        initial_diffusion_rate: The initial diffusion rate used to sample the noise schedule.
        final_diffusion_rate: The final diffusion rate used to sample the noise schedule.
        noise_schedule: Which noise schedule to sample.
        model_prediction_type: Defines what the model predicts.
    """

    # Timesteps.
    num_train_timesteps: int | None = None
    timesteps_spacing: TimestepSpacing | None = None
    timesteps_offset: int | None = None

    # Noise schedule.
    initial_diffusion_rate: float | None = None
    final_diffusion_rate: float | None = None
    noise_schedule: NoiseSchedule | None = None

    # Model output.
    model_prediction_type: ModelPredictionType | None = None
@dataclasses.dataclass(kw_only=True, frozen=True)
class ResolvedSolverParams(SolverParams):
    """Solver parameters in which every field is required.

    Each field of `SolverParams` is redeclared here without a default and
    without `None` in its type, so an instance always carries a concrete
    value for every parameter.
    """

    # Timesteps.
    num_train_timesteps: int
    timesteps_spacing: TimestepSpacing
    timesteps_offset: int

    # Noise schedule.
    initial_diffusion_rate: float
    final_diffusion_rate: float
    noise_schedule: NoiseSchedule

    # Model output.
    model_prediction_type: ModelPredictionType
class Solver(fl.Module, ABC):
@ -55,17 +102,23 @@ class Solver(fl.Module, ABC):
"""
timesteps: Tensor
params: ResolvedSolverParams
default_params = ResolvedSolverParams(
num_train_timesteps=1000,
timesteps_spacing=TimestepSpacing.LINSPACE,
timesteps_offset=0,
initial_diffusion_rate=8.5e-4,
final_diffusion_rate=1.2e-2,
noise_schedule=NoiseSchedule.QUADRATIC,
model_prediction_type=ModelPredictionType.NOISE,
)
def __init__(
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
timesteps_spacing: TimestepSpacing = TimestepSpacing.LINSPACE_FLOAT,
timesteps_offset: int = 0,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
first_inference_step: int = 0,
params: SolverParams | None = None,
device: Device | str = "cpu",
dtype: DType = float32,
) -> None:
@ -73,32 +126,33 @@ class Solver(fl.Module, ABC):
Args:
num_inference_steps: The number of inference steps to perform.
num_train_timesteps: The number of timesteps used to train the diffusion process.
timesteps_spacing: The spacing to use for the timesteps.
timesteps_offset: The offset to use for the timesteps.
initial_diffusion_rate: The initial diffusion rate used to sample the noise schedule.
final_diffusion_rate: The final diffusion rate used to sample the noise schedule.
noise_schedule: The noise schedule used to sample the noise schedule.
first_inference_step: The first inference step to perform.
params: The common parameters for solvers.
device: The PyTorch device to use for the solver's tensors.
dtype: The PyTorch data type to use for the solver's tensors.
"""
super().__init__()
self.num_inference_steps = num_inference_steps
self.num_train_timesteps = num_train_timesteps
self.timesteps_spacing = timesteps_spacing
self.timesteps_offset = timesteps_offset
self.initial_diffusion_rate = initial_diffusion_rate
self.final_diffusion_rate = final_diffusion_rate
self.noise_schedule = noise_schedule
self.first_inference_step = first_inference_step
self.params = self.resolve_params(params)
self.scale_factors = self.sample_noise_schedule()
self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
self.timesteps = self._generate_timesteps()
self.to(device=device, dtype=dtype)
def resolve_params(self, params: SolverParams | None) -> ResolvedSolverParams:
    """Merge user-supplied parameters with this solver's defaults.

    Args:
        params: The parameters to apply on top of `default_params`. Fields
            set to `None` (or `params` itself being `None`) fall back to
            the default value.

    Returns:
        A new parameter object built from `default_params` with every
        non-`None` field of `params` applied.
    """
    overrides = {}
    if params is not None:
        # Only explicitly set fields override the defaults.
        overrides = {name: value for name, value in dataclasses.asdict(params).items() if value is not None}
    # With no overrides this is a plain copy of the defaults.
    return dataclasses.replace(self.default_params, **overrides)
@abstractmethod
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
"""Apply a step of the diffusion process using the Solver.
@ -131,9 +185,9 @@ class Solver(fl.Module, ABC):
"""
max_timestep = num_train_timesteps - 1 + offset
match spacing:
case TimestepSpacing.LINSPACE_FLOAT:
case TimestepSpacing.LINSPACE:
return tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=float32).flip(0)
case TimestepSpacing.LINSPACE_INT:
case TimestepSpacing.LINSPACE_ROUNDED:
return tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0)
case TimestepSpacing.LEADING:
step_ratio = num_train_timesteps // num_inference_steps
@ -142,20 +196,15 @@ class Solver(fl.Module, ABC):
step_ratio = num_train_timesteps // num_inference_steps
max_timestep = num_train_timesteps - 1 + offset
return arange(max_timestep, offset, -step_ratio)
case TimestepSpacing.TRAILING_ALT:
# We use numpy here because:
# numpy.linspace(0,999,31)[15] is 499.49999999999994
# torch.linspace(0,999,31)[15] is 499.5
# and we want the same result as the original DPM codebase.
np_space = np.linspace(offset, max_timestep, num_inference_steps + 1).round().astype(int)[1:]
return tensor(np_space).flip(0)
case TimestepSpacing.CUSTOM:
raise RuntimeError("generate_timesteps called with custom spacing")
def _generate_timesteps(self) -> Tensor:
return self.generate_timesteps(
spacing=self.timesteps_spacing,
spacing=self.params.timesteps_spacing,
num_inference_steps=self.num_inference_steps,
num_train_timesteps=self.num_train_timesteps,
offset=self.timesteps_offset,
num_train_timesteps=self.params.num_train_timesteps,
offset=self.params.timesteps_offset,
)
def add_noise(
@ -239,15 +288,10 @@ class Solver(fl.Module, ABC):
Returns:
A new solver instance with the specified parameters.
"""
num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step
return self.__class__(
num_inference_steps=num_inference_steps,
num_train_timesteps=self.num_train_timesteps,
initial_diffusion_rate=self.initial_diffusion_rate,
final_diffusion_rate=self.final_diffusion_rate,
noise_schedule=self.noise_schedule,
first_inference_step=first_inference_step,
num_inference_steps=self.num_inference_steps if num_inference_steps is None else num_inference_steps,
first_inference_step=self.first_inference_step if first_inference_step is None else first_inference_step,
params=dataclasses.replace(self.params),
device=self.device,
dtype=self.dtype,
)
@ -282,9 +326,9 @@ class Solver(fl.Module, ABC):
"""
return (
linspace(
start=self.initial_diffusion_rate ** (1 / power),
end=self.final_diffusion_rate ** (1 / power),
steps=self.num_train_timesteps,
start=self.params.initial_diffusion_rate ** (1 / power),
end=self.params.final_diffusion_rate ** (1 / power),
steps=self.params.num_train_timesteps,
)
** power
)
@ -295,7 +339,7 @@ class Solver(fl.Module, ABC):
Returns:
A tensor representing the noise schedule.
"""
match self.noise_schedule:
match self.params.noise_schedule:
case "uniform":
return 1 - self.sample_power_distribution(1)
case "quadratic":
@ -303,7 +347,7 @@ class Solver(fl.Module, ABC):
case "karras":
return 1 - self.sample_power_distribution(7)
case _:
raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")
raise ValueError(f"Unknown noise schedule: {self.params.noise_schedule}")
def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver":
"""Move the solver to the specified device and data type.

View file

@ -27,7 +27,7 @@ from refiners.foundationals.latent_diffusion.lora import SDLoraManager
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
from refiners.foundationals.latent_diffusion.restart import Restart
from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule
from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule, SolverParams
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAdapter
@ -600,7 +600,7 @@ def sd15_ddim_karras(
warn("not running on CPU, skipping")
pytest.skip()
ddim_solver = DDIM(num_inference_steps=20, noise_schedule=NoiseSchedule.KARRAS)
ddim_solver = DDIM(num_inference_steps=20, params=SolverParams(noise_schedule=NoiseSchedule.KARRAS))
sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)

View file

@ -11,8 +11,10 @@ from refiners.foundationals.latent_diffusion.solvers import (
DPMSolver,
Euler,
LCMSolver,
ModelPredictionType,
NoiseSchedule,
Solver,
SolverParams,
TimestepSpacing,
)
@ -41,7 +43,10 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
final_sigmas_type="sigma_min", # default before Diffusers 0.26.0
)
diffusers_scheduler.set_timesteps(n_steps)
refiners_scheduler = DPMSolver(num_inference_steps=n_steps, last_step_first_order=last_step_first_order)
refiners_scheduler = DPMSolver(
num_inference_steps=n_steps,
last_step_first_order=last_step_first_order,
)
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
sample = randn(1, 3, 32, 32)
@ -80,10 +85,12 @@ def test_ddim_diffusers():
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
def test_euler_diffusers():
@pytest.mark.parametrize("model_prediction_type", [ModelPredictionType.NOISE, ModelPredictionType.SAMPLE])
def test_euler_diffusers(model_prediction_type: ModelPredictionType):
from diffusers import EulerDiscreteScheduler # type: ignore
manual_seed(0)
diffusers_prediction_type = "epsilon" if model_prediction_type == ModelPredictionType.NOISE else "sample"
diffusers_scheduler = EulerDiscreteScheduler(
beta_end=0.012,
beta_schedule="scaled_linear",
@ -92,9 +99,10 @@ def test_euler_diffusers():
steps_offset=1,
timestep_spacing="linspace",
use_karras_sigmas=False,
prediction_type=diffusers_prediction_type,
)
diffusers_scheduler.set_timesteps(30)
refiners_scheduler = Euler(num_inference_steps=30)
refiners_scheduler = Euler(num_inference_steps=30, params=SolverParams(model_prediction_type=model_prediction_type))
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
sample = randn(1, 4, 32, 32)
@ -192,16 +200,20 @@ def test_solver_device(test_device: Device):
@pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS])
def test_solver_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device):
scheduler = DDIM(num_inference_steps=30, device=test_device, noise_schedule=noise_schedule)
scheduler = DDIM(
num_inference_steps=30,
params=SolverParams(noise_schedule=noise_schedule),
device=test_device,
)
assert len(scheduler.scale_factors) == 1000
assert scheduler.scale_factors[0] == 1 - scheduler.initial_diffusion_rate
assert scheduler.scale_factors[-1] == 1 - scheduler.final_diffusion_rate
assert scheduler.scale_factors[0] == 1 - scheduler.params.initial_diffusion_rate
assert scheduler.scale_factors[-1] == 1 - scheduler.params.final_diffusion_rate
def test_solver_timestep_spacing():
# Tests we get the results from [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 2.
linspace_int = Solver.generate_timesteps(
spacing=TimestepSpacing.LINSPACE_INT,
spacing=TimestepSpacing.LINSPACE_ROUNDED,
num_inference_steps=10,
num_train_timesteps=1000,
offset=1,