mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
refactor solver params, add sample prediction type
This commit is contained in:
parent
ddc1cf8ca7
commit
bf0ba58541
|
@ -3,6 +3,23 @@ from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM
|
|||
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
||||
from refiners.foundationals.latent_diffusion.solvers.euler import Euler
|
||||
from refiners.foundationals.latent_diffusion.solvers.lcm import LCMSolver
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||
ModelPredictionType,
|
||||
NoiseSchedule,
|
||||
Solver,
|
||||
SolverParams,
|
||||
TimestepSpacing,
|
||||
)
|
||||
|
||||
__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler", "LCMSolver", "NoiseSchedule", "TimestepSpacing"]
|
||||
__all__ = [
|
||||
"Solver",
|
||||
"SolverParams",
|
||||
"DPMSolver",
|
||||
"DDPM",
|
||||
"DDIM",
|
||||
"Euler",
|
||||
"LCMSolver",
|
||||
"ModelPredictionType",
|
||||
"NoiseSchedule",
|
||||
"TimestepSpacing",
|
||||
]
|
||||
|
|
|
@ -1,6 +1,13 @@
|
|||
import dataclasses
|
||||
|
||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor
|
||||
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||
ModelPredictionType,
|
||||
Solver,
|
||||
SolverParams,
|
||||
TimestepSpacing,
|
||||
)
|
||||
|
||||
|
||||
class DDIM(Solver):
|
||||
|
@ -9,42 +16,36 @@ class DDIM(Solver):
|
|||
See [[arXiv:2010.02502] Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) for more details.
|
||||
"""
|
||||
|
||||
default_params = dataclasses.replace(
|
||||
Solver.default_params,
|
||||
timesteps_spacing=TimestepSpacing.LEADING,
|
||||
timesteps_offset=1,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
num_train_timesteps: int = 1_000,
|
||||
timesteps_spacing: TimestepSpacing = TimestepSpacing.LEADING,
|
||||
timesteps_offset: int = 1,
|
||||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||
first_inference_step: int = 0,
|
||||
params: SolverParams | None = None,
|
||||
device: Device | str = "cpu",
|
||||
dtype: Dtype = float32,
|
||||
) -> None:
|
||||
"""Initializes a new DDIM solver.
|
||||
|
||||
Args:
|
||||
num_inference_steps: The number of inference steps.
|
||||
num_train_timesteps: The number of training timesteps.
|
||||
timesteps_spacing: The spacing to use for the timesteps.
|
||||
timesteps_offset: The offset to use for the timesteps.
|
||||
initial_diffusion_rate: The initial diffusion rate.
|
||||
final_diffusion_rate: The final diffusion rate.
|
||||
noise_schedule: The noise schedule.
|
||||
first_inference_step: The first inference step.
|
||||
num_inference_steps: The number of inference steps to perform.
|
||||
first_inference_step: The first inference step to perform.
|
||||
params: The common parameters for solvers.
|
||||
device: The PyTorch device to use.
|
||||
dtype: The PyTorch data type to use.
|
||||
"""
|
||||
if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
|
||||
raise NotImplementedError
|
||||
|
||||
super().__init__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
timesteps_spacing=timesteps_spacing,
|
||||
timesteps_offset=timesteps_offset,
|
||||
initial_diffusion_rate=initial_diffusion_rate,
|
||||
final_diffusion_rate=final_diffusion_rate,
|
||||
noise_schedule=noise_schedule,
|
||||
first_inference_step=first_inference_step,
|
||||
params=params,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
|
|
@ -1,6 +1,13 @@
|
|||
import dataclasses
|
||||
|
||||
from torch import Generator, Tensor, device as Device
|
||||
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||
ModelPredictionType,
|
||||
Solver,
|
||||
SolverParams,
|
||||
TimestepSpacing,
|
||||
)
|
||||
|
||||
|
||||
class DDPM(Solver):
|
||||
|
@ -13,37 +20,34 @@ class DDPM(Solver):
|
|||
See [[arXiv:2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) for more details.
|
||||
"""
|
||||
|
||||
default_params = dataclasses.replace(
|
||||
Solver.default_params,
|
||||
timesteps_spacing=TimestepSpacing.LEADING,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
num_train_timesteps: int = 1_000,
|
||||
timesteps_spacing: TimestepSpacing = TimestepSpacing.LEADING,
|
||||
timesteps_offset: int = 0,
|
||||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
first_inference_step: int = 0,
|
||||
params: SolverParams | None = None,
|
||||
device: Device | str = "cpu",
|
||||
) -> None:
|
||||
"""Initializes a new DDPM solver.
|
||||
|
||||
Args:
|
||||
num_inference_steps: The number of inference steps.
|
||||
num_train_timesteps: The number of training timesteps.
|
||||
timesteps_spacing: The spacing to use for the timesteps.
|
||||
timesteps_offset: The offset to use for the timesteps.
|
||||
initial_diffusion_rate: The initial diffusion rate.
|
||||
final_diffusion_rate: The final diffusion rate.
|
||||
first_inference_step: The first inference step.
|
||||
num_inference_steps: The number of inference steps to perform.
|
||||
first_inference_step: The first inference step to perform.
|
||||
params: The common parameters for solvers.
|
||||
device: The PyTorch device to use.
|
||||
"""
|
||||
|
||||
if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
|
||||
raise NotImplementedError
|
||||
|
||||
super().__init__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
timesteps_spacing=timesteps_spacing,
|
||||
timesteps_offset=timesteps_offset,
|
||||
initial_diffusion_rate=initial_diffusion_rate,
|
||||
final_diffusion_rate=final_diffusion_rate,
|
||||
first_inference_step=first_inference_step,
|
||||
params=params,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
|
|
@ -1,8 +1,15 @@
|
|||
import dataclasses
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor
|
||||
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||
ModelPredictionType,
|
||||
Solver,
|
||||
SolverParams,
|
||||
TimestepSpacing,
|
||||
)
|
||||
|
||||
|
||||
class DPMSolver(Solver):
|
||||
|
@ -18,44 +25,37 @@ class DPMSolver(Solver):
|
|||
for the last step of the diffusion.
|
||||
"""
|
||||
|
||||
default_params = dataclasses.replace(
|
||||
Solver.default_params,
|
||||
timesteps_spacing=TimestepSpacing.CUSTOM,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
num_train_timesteps: int = 1_000,
|
||||
timesteps_spacing: TimestepSpacing = TimestepSpacing.TRAILING_ALT,
|
||||
timesteps_offset: int = 0,
|
||||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
last_step_first_order: bool = False,
|
||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||
first_inference_step: int = 0,
|
||||
params: SolverParams | None = None,
|
||||
last_step_first_order: bool = False,
|
||||
device: Device | str = "cpu",
|
||||
dtype: Dtype = float32,
|
||||
):
|
||||
"""Initializes a new DPM solver.
|
||||
|
||||
Args:
|
||||
num_inference_steps: The number of inference steps.
|
||||
num_train_timesteps: The number of training timesteps.
|
||||
timesteps_spacing: The spacing to use for the timesteps.
|
||||
timesteps_offset: The offset to use for the timesteps.
|
||||
initial_diffusion_rate: The initial diffusion rate.
|
||||
final_diffusion_rate: The final diffusion rate.
|
||||
num_inference_steps: The number of inference steps to perform.
|
||||
first_inference_step: The first inference step to perform.
|
||||
params: The common parameters for solvers.
|
||||
last_step_first_order: Use a first-order update for the last step.
|
||||
noise_schedule: The noise schedule.
|
||||
first_inference_step: The first inference step.
|
||||
device: The PyTorch device to use.
|
||||
dtype: The PyTorch data type to use.
|
||||
"""
|
||||
if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
|
||||
raise NotImplementedError
|
||||
|
||||
super().__init__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
timesteps_spacing=timesteps_spacing,
|
||||
timesteps_offset=timesteps_offset,
|
||||
initial_diffusion_rate=initial_diffusion_rate,
|
||||
final_diffusion_rate=final_diffusion_rate,
|
||||
noise_schedule=noise_schedule,
|
||||
first_inference_step=first_inference_step,
|
||||
params=params,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
@ -80,6 +80,19 @@ class DPMSolver(Solver):
|
|||
r.last_step_first_order = self.last_step_first_order
|
||||
return r
|
||||
|
||||
def _generate_timesteps(self) -> Tensor:
|
||||
if self.params.timesteps_spacing != TimestepSpacing.CUSTOM:
|
||||
return super()._generate_timesteps()
|
||||
|
||||
# We use numpy here because:
|
||||
# numpy.linspace(0,999,31)[15] is 499.49999999999994
|
||||
# torch.linspace(0,999,31)[15] is 499.5
|
||||
# and we want the same result as the original DPM codebase.
|
||||
offset = self.params.timesteps_offset
|
||||
max_timestep = self.params.num_train_timesteps - 1 + offset
|
||||
np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
|
||||
return tensor(np_space).flip(0)
|
||||
|
||||
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||
"""Applies a first-order backward Euler update to the input data `x`.
|
||||
|
||||
|
|
|
@ -2,7 +2,12 @@ import numpy as np
|
|||
import torch
|
||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
|
||||
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||
ModelPredictionType,
|
||||
NoiseSchedule,
|
||||
Solver,
|
||||
SolverParams,
|
||||
)
|
||||
|
||||
|
||||
class Euler(Solver):
|
||||
|
@ -15,41 +20,27 @@ class Euler(Solver):
|
|||
def __init__(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
num_train_timesteps: int = 1_000,
|
||||
timesteps_spacing: TimestepSpacing = TimestepSpacing.LINSPACE_FLOAT,
|
||||
timesteps_offset: int = 0,
|
||||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||
first_inference_step: int = 0,
|
||||
params: SolverParams | None = None,
|
||||
device: Device | str = "cpu",
|
||||
dtype: Dtype = float32,
|
||||
):
|
||||
"""Initializes a new Euler solver.
|
||||
|
||||
Args:
|
||||
num_inference_steps: The number of inference steps.
|
||||
num_train_timesteps: The number of training timesteps.
|
||||
timesteps_spacing: The spacing to use for the timesteps.
|
||||
timesteps_offset: The offset to use for the timesteps.
|
||||
initial_diffusion_rate: The initial diffusion rate.
|
||||
final_diffusion_rate: The final diffusion rate.
|
||||
noise_schedule: The noise schedule.
|
||||
first_inference_step: The first inference step.
|
||||
num_inference_steps: The number of inference steps to perform.
|
||||
first_inference_step: The first inference step to perform.
|
||||
params: The common parameters for solvers.
|
||||
device: The PyTorch device to use.
|
||||
dtype: The PyTorch data type to use.
|
||||
"""
|
||||
if noise_schedule != NoiseSchedule.QUADRATIC:
|
||||
if params and params.noise_schedule not in (NoiseSchedule.QUADRATIC, None):
|
||||
raise NotImplementedError
|
||||
|
||||
super().__init__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
timesteps_spacing=timesteps_spacing,
|
||||
timesteps_offset=timesteps_offset,
|
||||
initial_diffusion_rate=initial_diffusion_rate,
|
||||
final_diffusion_rate=final_diffusion_rate,
|
||||
noise_schedule=noise_schedule,
|
||||
first_inference_step=first_inference_step,
|
||||
params=params,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
@ -85,7 +76,7 @@ class Euler(Solver):
|
|||
|
||||
Args:
|
||||
x: The input tensor to apply the diffusion process to.
|
||||
predicted_noise: The predicted noise tensor for the current step.
|
||||
predicted_noise: The predicted noise tensor for the current step (or x0 if the prediction type is SAMPLE).
|
||||
step: The current step of the diffusion process.
|
||||
generator: The random number generator to use for sampling noise (ignored, this solver is deterministic).
|
||||
|
||||
|
@ -93,4 +84,11 @@ class Euler(Solver):
|
|||
The denoised version of the input data `x`.
|
||||
"""
|
||||
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
||||
|
||||
if self.params.model_prediction_type == ModelPredictionType.SAMPLE:
|
||||
x0 = predicted_noise # the model does not actually predict the noise but x0
|
||||
ratio = self.sigmas[step + 1] / self.sigmas[step]
|
||||
return ratio * x + (1 - ratio) * x0
|
||||
|
||||
assert self.params.model_prediction_type == ModelPredictionType.NOISE
|
||||
return x + predicted_noise * (self.sigmas[step + 1] - self.sigmas[step])
|
||||
|
|
|
@ -1,7 +1,14 @@
|
|||
import dataclasses
|
||||
|
||||
import torch
|
||||
|
||||
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||
ModelPredictionType,
|
||||
Solver,
|
||||
SolverParams,
|
||||
TimestepSpacing,
|
||||
)
|
||||
|
||||
|
||||
class LCMSolver(Solver):
|
||||
|
@ -15,55 +22,52 @@ class LCMSolver(Solver):
|
|||
for details.
|
||||
"""
|
||||
|
||||
# The spacing parameter is actually the spacing of the underlying DPM solver.
|
||||
default_params = dataclasses.replace(Solver.default_params, timesteps_spacing=TimestepSpacing.TRAILING)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
num_train_timesteps: int = 1_000,
|
||||
timesteps_spacing: TimestepSpacing = TimestepSpacing.TRAILING,
|
||||
timesteps_offset: int = 0,
|
||||
first_inference_step: int = 0,
|
||||
params: SolverParams | None = None,
|
||||
num_orig_steps: int = 50,
|
||||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||
device: torch.device | str = "cpu",
|
||||
dtype: torch.dtype = torch.float32,
|
||||
):
|
||||
"""Initializes a new LCM solver.
|
||||
|
||||
Args:
|
||||
num_inference_steps: The number of inference steps.
|
||||
num_train_timesteps: The number of training timesteps.
|
||||
timesteps_spacing: The spacing to use for the timesteps.
|
||||
timesteps_offset: The offset to use for the timesteps.
|
||||
num_inference_steps: The number of inference steps to perform.
|
||||
first_inference_step: The first inference step to perform.
|
||||
params: The common parameters for solvers.
|
||||
num_orig_steps: The number of inference steps of the emulated DPM solver.
|
||||
initial_diffusion_rate: The initial diffusion rate.
|
||||
final_diffusion_rate: The final diffusion rate.
|
||||
noise_schedule: The noise schedule.
|
||||
device: The PyTorch device to use.
|
||||
dtype: The PyTorch data type to use.
|
||||
"""
|
||||
|
||||
assert (
|
||||
num_orig_steps >= num_inference_steps
|
||||
), f"num_orig_steps ({num_orig_steps}) < num_inference_steps ({num_inference_steps})"
|
||||
|
||||
params = self.resolve_params(params)
|
||||
if params.model_prediction_type != ModelPredictionType.NOISE:
|
||||
raise NotImplementedError
|
||||
|
||||
self._dpm = [
|
||||
DPMSolver(
|
||||
num_inference_steps=num_orig_steps,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
timesteps_spacing=timesteps_spacing,
|
||||
params=SolverParams(
|
||||
num_train_timesteps=params.num_train_timesteps,
|
||||
timesteps_spacing=params.timesteps_spacing,
|
||||
),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
]
|
||||
|
||||
super().__init__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
timesteps_spacing=timesteps_spacing,
|
||||
timesteps_offset=timesteps_offset,
|
||||
initial_diffusion_rate=initial_diffusion_rate,
|
||||
final_diffusion_rate=final_diffusion_rate,
|
||||
noise_schedule=noise_schedule,
|
||||
first_inference_step=first_inference_step,
|
||||
params=params,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import dataclasses
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import TypeVar
|
||||
|
@ -30,18 +31,64 @@ class TimestepSpacing(str, Enum):
|
|||
See [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 2.
|
||||
|
||||
Attributes:
|
||||
LINSPACE_FLOAT: Sample N steps with linear interpolation, return a floating-point tensor.
|
||||
LINSPACE_INT: Same as LINSPACE_FLOAT but return an integer tensor with rounded timesteps.
|
||||
LINSPACE: Sample N steps with linear interpolation, return a floating-point tensor.
|
||||
LINSPACE_ROUNDED: Same as LINSPACE but return an integer tensor with rounded timesteps.
|
||||
LEADING: Sample N+1 steps, do not include the last timestep (i.e. bad - non-zero SNR). Used in DDIM.
|
||||
TRAILING: Sample N+1 steps, do not include the first timestep.
|
||||
TRAILING_ALT: Variant of TRAILING used in DPM.
|
||||
CUSTOM: Use custom timespacing in solver (override `_generate_timesteps`, see DPM).
|
||||
"""
|
||||
|
||||
LINSPACE_FLOAT = "linspace_float"
|
||||
LINSPACE_INT = "linspace_int"
|
||||
LINSPACE = "linspace"
|
||||
LINSPACE_ROUNDED = "linspace_rounded"
|
||||
LEADING = "leading"
|
||||
TRAILING = "trailing"
|
||||
TRAILING_ALT = "trailing_alt"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class ModelPredictionType(str, Enum):
|
||||
"""An enumeration of possible outputs of the model.
|
||||
|
||||
Attributes:
|
||||
NOISE: The model predicts the noise (epsilon).
|
||||
SAMPLE: The model predicts the denoised sample (x0).
|
||||
"""
|
||||
|
||||
NOISE = "noise"
|
||||
SAMPLE = "sample"
|
||||
|
||||
|
||||
@dataclasses.dataclass(kw_only=True, frozen=True)
|
||||
class SolverParams:
|
||||
"""Common parameters for solvers.
|
||||
|
||||
Args:
|
||||
num_train_timesteps: The number of timesteps used to train the diffusion process.
|
||||
timesteps_spacing: The spacing to use for the timesteps.
|
||||
timesteps_offset: The offset to use for the timesteps.
|
||||
initial_diffusion_rate: The initial diffusion rate used to sample the noise schedule.
|
||||
final_diffusion_rate: The final diffusion rate used to sample the noise schedule.
|
||||
noise_schedule: The noise schedule used to sample the noise schedule.
|
||||
model_prediction_type: Defines what the model predicts.
|
||||
"""
|
||||
|
||||
num_train_timesteps: int | None = None
|
||||
timesteps_spacing: TimestepSpacing | None = None
|
||||
timesteps_offset: int | None = None
|
||||
initial_diffusion_rate: float | None = None
|
||||
final_diffusion_rate: float | None = None
|
||||
noise_schedule: NoiseSchedule | None = None
|
||||
model_prediction_type: ModelPredictionType | None = None
|
||||
|
||||
|
||||
@dataclasses.dataclass(kw_only=True, frozen=True)
|
||||
class ResolvedSolverParams(SolverParams):
|
||||
num_train_timesteps: int
|
||||
timesteps_spacing: TimestepSpacing
|
||||
timesteps_offset: int
|
||||
initial_diffusion_rate: float
|
||||
final_diffusion_rate: float
|
||||
noise_schedule: NoiseSchedule
|
||||
model_prediction_type: ModelPredictionType
|
||||
|
||||
|
||||
class Solver(fl.Module, ABC):
|
||||
|
@ -55,17 +102,23 @@ class Solver(fl.Module, ABC):
|
|||
"""
|
||||
|
||||
timesteps: Tensor
|
||||
params: ResolvedSolverParams
|
||||
|
||||
default_params = ResolvedSolverParams(
|
||||
num_train_timesteps=1000,
|
||||
timesteps_spacing=TimestepSpacing.LINSPACE,
|
||||
timesteps_offset=0,
|
||||
initial_diffusion_rate=8.5e-4,
|
||||
final_diffusion_rate=1.2e-2,
|
||||
noise_schedule=NoiseSchedule.QUADRATIC,
|
||||
model_prediction_type=ModelPredictionType.NOISE,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
num_train_timesteps: int = 1_000,
|
||||
timesteps_spacing: TimestepSpacing = TimestepSpacing.LINSPACE_FLOAT,
|
||||
timesteps_offset: int = 0,
|
||||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||
first_inference_step: int = 0,
|
||||
params: SolverParams | None = None,
|
||||
device: Device | str = "cpu",
|
||||
dtype: DType = float32,
|
||||
) -> None:
|
||||
|
@ -73,32 +126,33 @@ class Solver(fl.Module, ABC):
|
|||
|
||||
Args:
|
||||
num_inference_steps: The number of inference steps to perform.
|
||||
num_train_timesteps: The number of timesteps used to train the diffusion process.
|
||||
timesteps_spacing: The spacing to use for the timesteps.
|
||||
timesteps_offset: The offset to use for the timesteps.
|
||||
initial_diffusion_rate: The initial diffusion rate used to sample the noise schedule.
|
||||
final_diffusion_rate: The final diffusion rate used to sample the noise schedule.
|
||||
noise_schedule: The noise schedule used to sample the noise schedule.
|
||||
first_inference_step: The first inference step to perform.
|
||||
params: The common parameters for solvers.
|
||||
device: The PyTorch device to use for the solver's tensors.
|
||||
dtype: The PyTorch data type to use for the solver's tensors.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
self.timesteps_spacing = timesteps_spacing
|
||||
self.timesteps_offset = timesteps_offset
|
||||
self.initial_diffusion_rate = initial_diffusion_rate
|
||||
self.final_diffusion_rate = final_diffusion_rate
|
||||
self.noise_schedule = noise_schedule
|
||||
self.first_inference_step = first_inference_step
|
||||
self.params = self.resolve_params(params)
|
||||
|
||||
self.scale_factors = self.sample_noise_schedule()
|
||||
self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
|
||||
self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
|
||||
self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
|
||||
self.timesteps = self._generate_timesteps()
|
||||
|
||||
self.to(device=device, dtype=dtype)
|
||||
|
||||
def resolve_params(self, params: SolverParams | None) -> ResolvedSolverParams:
|
||||
if params is None:
|
||||
return dataclasses.replace(self.default_params)
|
||||
return dataclasses.replace(
|
||||
self.default_params,
|
||||
**{k: v for k, v in dataclasses.asdict(params).items() if v is not None},
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||
"""Apply a step of the diffusion process using the Solver.
|
||||
|
@ -131,9 +185,9 @@ class Solver(fl.Module, ABC):
|
|||
"""
|
||||
max_timestep = num_train_timesteps - 1 + offset
|
||||
match spacing:
|
||||
case TimestepSpacing.LINSPACE_FLOAT:
|
||||
case TimestepSpacing.LINSPACE:
|
||||
return tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=float32).flip(0)
|
||||
case TimestepSpacing.LINSPACE_INT:
|
||||
case TimestepSpacing.LINSPACE_ROUNDED:
|
||||
return tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0)
|
||||
case TimestepSpacing.LEADING:
|
||||
step_ratio = num_train_timesteps // num_inference_steps
|
||||
|
@ -142,20 +196,15 @@ class Solver(fl.Module, ABC):
|
|||
step_ratio = num_train_timesteps // num_inference_steps
|
||||
max_timestep = num_train_timesteps - 1 + offset
|
||||
return arange(max_timestep, offset, -step_ratio)
|
||||
case TimestepSpacing.TRAILING_ALT:
|
||||
# We use numpy here because:
|
||||
# numpy.linspace(0,999,31)[15] is 499.49999999999994
|
||||
# torch.linspace(0,999,31)[15] is 499.5
|
||||
# and we want the same result as the original DPM codebase.
|
||||
np_space = np.linspace(offset, max_timestep, num_inference_steps + 1).round().astype(int)[1:]
|
||||
return tensor(np_space).flip(0)
|
||||
case TimestepSpacing.CUSTOM:
|
||||
raise RuntimeError("generate_timesteps called with custom spacing")
|
||||
|
||||
def _generate_timesteps(self) -> Tensor:
|
||||
return self.generate_timesteps(
|
||||
spacing=self.timesteps_spacing,
|
||||
spacing=self.params.timesteps_spacing,
|
||||
num_inference_steps=self.num_inference_steps,
|
||||
num_train_timesteps=self.num_train_timesteps,
|
||||
offset=self.timesteps_offset,
|
||||
num_train_timesteps=self.params.num_train_timesteps,
|
||||
offset=self.params.timesteps_offset,
|
||||
)
|
||||
|
||||
def add_noise(
|
||||
|
@ -239,15 +288,10 @@ class Solver(fl.Module, ABC):
|
|||
Returns:
|
||||
A new solver instance with the specified parameters.
|
||||
"""
|
||||
num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
|
||||
first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step
|
||||
return self.__class__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_train_timesteps=self.num_train_timesteps,
|
||||
initial_diffusion_rate=self.initial_diffusion_rate,
|
||||
final_diffusion_rate=self.final_diffusion_rate,
|
||||
noise_schedule=self.noise_schedule,
|
||||
first_inference_step=first_inference_step,
|
||||
num_inference_steps=self.num_inference_steps if num_inference_steps is None else num_inference_steps,
|
||||
first_inference_step=self.first_inference_step if first_inference_step is None else first_inference_step,
|
||||
params=dataclasses.replace(self.params),
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
@ -282,9 +326,9 @@ class Solver(fl.Module, ABC):
|
|||
"""
|
||||
return (
|
||||
linspace(
|
||||
start=self.initial_diffusion_rate ** (1 / power),
|
||||
end=self.final_diffusion_rate ** (1 / power),
|
||||
steps=self.num_train_timesteps,
|
||||
start=self.params.initial_diffusion_rate ** (1 / power),
|
||||
end=self.params.final_diffusion_rate ** (1 / power),
|
||||
steps=self.params.num_train_timesteps,
|
||||
)
|
||||
** power
|
||||
)
|
||||
|
@ -295,7 +339,7 @@ class Solver(fl.Module, ABC):
|
|||
Returns:
|
||||
A tensor representing the noise schedule.
|
||||
"""
|
||||
match self.noise_schedule:
|
||||
match self.params.noise_schedule:
|
||||
case "uniform":
|
||||
return 1 - self.sample_power_distribution(1)
|
||||
case "quadratic":
|
||||
|
@ -303,7 +347,7 @@ class Solver(fl.Module, ABC):
|
|||
case "karras":
|
||||
return 1 - self.sample_power_distribution(7)
|
||||
case _:
|
||||
raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")
|
||||
raise ValueError(f"Unknown noise schedule: {self.params.noise_schedule}")
|
||||
|
||||
def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver":
|
||||
"""Move the solver to the specified device and data type.
|
||||
|
|
|
@ -27,7 +27,7 @@ from refiners.foundationals.latent_diffusion.lora import SDLoraManager
|
|||
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
|
||||
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
|
||||
from refiners.foundationals.latent_diffusion.restart import Restart
|
||||
from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule
|
||||
from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule, SolverParams
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
||||
from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAdapter
|
||||
|
@ -600,7 +600,7 @@ def sd15_ddim_karras(
|
|||
warn("not running on CPU, skipping")
|
||||
pytest.skip()
|
||||
|
||||
ddim_solver = DDIM(num_inference_steps=20, noise_schedule=NoiseSchedule.KARRAS)
|
||||
ddim_solver = DDIM(num_inference_steps=20, params=SolverParams(noise_schedule=NoiseSchedule.KARRAS))
|
||||
sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
|
||||
|
||||
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
||||
|
|
|
@ -11,8 +11,10 @@ from refiners.foundationals.latent_diffusion.solvers import (
|
|||
DPMSolver,
|
||||
Euler,
|
||||
LCMSolver,
|
||||
ModelPredictionType,
|
||||
NoiseSchedule,
|
||||
Solver,
|
||||
SolverParams,
|
||||
TimestepSpacing,
|
||||
)
|
||||
|
||||
|
@ -41,7 +43,10 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
|
|||
final_sigmas_type="sigma_min", # default before Diffusers 0.26.0
|
||||
)
|
||||
diffusers_scheduler.set_timesteps(n_steps)
|
||||
refiners_scheduler = DPMSolver(num_inference_steps=n_steps, last_step_first_order=last_step_first_order)
|
||||
refiners_scheduler = DPMSolver(
|
||||
num_inference_steps=n_steps,
|
||||
last_step_first_order=last_step_first_order,
|
||||
)
|
||||
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
|
||||
|
||||
sample = randn(1, 3, 32, 32)
|
||||
|
@ -80,10 +85,12 @@ def test_ddim_diffusers():
|
|||
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
||||
|
||||
|
||||
def test_euler_diffusers():
|
||||
@pytest.mark.parametrize("model_prediction_type", [ModelPredictionType.NOISE, ModelPredictionType.SAMPLE])
|
||||
def test_euler_diffusers(model_prediction_type: ModelPredictionType):
|
||||
from diffusers import EulerDiscreteScheduler # type: ignore
|
||||
|
||||
manual_seed(0)
|
||||
diffusers_prediction_type = "epsilon" if model_prediction_type == ModelPredictionType.NOISE else "sample"
|
||||
diffusers_scheduler = EulerDiscreteScheduler(
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
|
@ -92,9 +99,10 @@ def test_euler_diffusers():
|
|||
steps_offset=1,
|
||||
timestep_spacing="linspace",
|
||||
use_karras_sigmas=False,
|
||||
prediction_type=diffusers_prediction_type,
|
||||
)
|
||||
diffusers_scheduler.set_timesteps(30)
|
||||
refiners_scheduler = Euler(num_inference_steps=30)
|
||||
refiners_scheduler = Euler(num_inference_steps=30, params=SolverParams(model_prediction_type=model_prediction_type))
|
||||
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
|
||||
|
||||
sample = randn(1, 4, 32, 32)
|
||||
|
@ -192,16 +200,20 @@ def test_solver_device(test_device: Device):
|
|||
|
||||
@pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS])
|
||||
def test_solver_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device):
|
||||
scheduler = DDIM(num_inference_steps=30, device=test_device, noise_schedule=noise_schedule)
|
||||
scheduler = DDIM(
|
||||
num_inference_steps=30,
|
||||
params=SolverParams(noise_schedule=noise_schedule),
|
||||
device=test_device,
|
||||
)
|
||||
assert len(scheduler.scale_factors) == 1000
|
||||
assert scheduler.scale_factors[0] == 1 - scheduler.initial_diffusion_rate
|
||||
assert scheduler.scale_factors[-1] == 1 - scheduler.final_diffusion_rate
|
||||
assert scheduler.scale_factors[0] == 1 - scheduler.params.initial_diffusion_rate
|
||||
assert scheduler.scale_factors[-1] == 1 - scheduler.params.final_diffusion_rate
|
||||
|
||||
|
||||
def test_solver_timestep_spacing():
|
||||
# Tests we get the results from [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 2.
|
||||
linspace_int = Solver.generate_timesteps(
|
||||
spacing=TimestepSpacing.LINSPACE_INT,
|
||||
spacing=TimestepSpacing.LINSPACE_ROUNDED,
|
||||
num_inference_steps=10,
|
||||
num_train_timesteps=1000,
|
||||
offset=1,
|
||||
|
|
Loading…
Reference in a new issue