(doc/fluxion/ld) add DDPM, DDIM, DPM++ and Euleur docstrings

This commit is contained in:
Laurent 2024-02-02 12:59:30 +00:00 committed by Laureηt
parent 24c11745cd
commit ff17991261
4 changed files with 132 additions and 26 deletions

View file

@ -4,6 +4,11 @@ from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule
class DDIM(Solver):
"""Denoising Diffusion Implicit Model (DDIM) solver.
See [[arXiv:2010.02502] Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) for more details.
"""
def __init__(
self,
num_inference_steps: int,
@ -15,6 +20,18 @@ class DDIM(Solver):
device: Device | str = "cpu",
dtype: Dtype = float32,
) -> None:
"""Initializes a new DDIM solver.
Args:
num_inference_steps: The number of inference steps.
num_train_timesteps: The number of training timesteps.
initial_diffusion_rate: The initial diffusion rate.
final_diffusion_rate: The final diffusion rate.
noise_schedule: The noise schedule.
first_inference_step: The first inference step.
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
super().__init__(
num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,

View file

@ -1,13 +1,16 @@
from torch import Generator, Tensor, arange, device as Device, dtype as DType
from torch import Generator, Tensor, arange, device as Device
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
class DDPM(Solver):
"""
Denoising Diffusion Probabilistic Model
"""Denoising Diffusion Probabilistic Model (DDPM) solver.
Only used for training Latent Diffusion models. Cannot be called.
Warning:
Only used for training Latent Diffusion models.
Cannot be called.
See [[arXiv:2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) for more details.
"""
def __init__(
@ -16,11 +19,19 @@ class DDPM(Solver):
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule | None = None, # ignored
first_inference_step: int = 0,
device: Device | str = "cpu",
dtype: DType | None = None, # ignored
) -> None:
"""Initializes a new DDPM solver.
Args:
num_inference_steps: The number of inference steps.
num_train_timesteps: The number of training timesteps.
initial_diffusion_rate: The initial diffusion rate.
final_diffusion_rate: The final diffusion rate.
first_inference_step: The first inference step.
device: The PyTorch device to use.
"""
super().__init__(
num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,

View file

@ -7,13 +7,16 @@ from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule
class DPMSolver(Solver):
"""
Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
"""Diffusion probabilistic models (DPMs) solver.
Regarding last_step_first_order: DPM-Solver++ is known to introduce artifacts
when used with SDXL and few steps. This parameter is a way to mitigate that
effect by using a first-order (Euler) update instead of a second-order update
for the last step of the diffusion.
See [[arXiv:2211.01095] DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models](https://arxiv.org/abs/2211.01095)
for more details.
Note:
Regarding last_step_first_order: DPM-Solver++ is known to introduce artifacts
when used with SDXL and few steps. This parameter is a way to mitigate that
effect by using a first-order (Euler) update instead of a second-order update
for the last step of the diffusion.
"""
def __init__(
@ -42,10 +45,16 @@ class DPMSolver(Solver):
self.last_step_first_order = last_step_first_order
def _generate_timesteps(self) -> Tensor:
# We need to use numpy here because:
# numpy.linspace(0,999,31)[15] is 499.49999999999994
# torch.linspace(0,999,31)[15] is 499.5
# ...and we want the same result as the original codebase.
"""Generate the timesteps used by the solver.
Note:
We need to use numpy here because:
- numpy.linspace(0,999,31)[15] is 499.49999999999994
- torch.linspace(0,999,31)[15] is 499.5
and we want the same result as the original codebase.
"""
return tensor(
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:],
).flip(0)
@ -55,6 +64,12 @@ class DPMSolver(Solver):
num_inference_steps: int | None,
first_inference_step: int | None = None,
) -> "DPMSolver":
"""Rebuilds the solver with new parameters.
Args:
num_inference_steps: The number of inference steps.
first_inference_step: The first inference step.
"""
r = super().rebuild(
num_inference_steps=num_inference_steps,
first_inference_step=first_inference_step,
@ -63,6 +78,16 @@ class DPMSolver(Solver):
return r
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
"""Applies a first-order backward Euler update to the input data `x`.
Args:
x: The input data.
noise: The predicted noise.
step: The current step.
Returns:
The denoised version of the input data `x`.
"""
current_timestep = self.timesteps[step]
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
@ -79,6 +104,15 @@ class DPMSolver(Solver):
return denoised_x
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
"""Applies a second-order backward Euler update to the input data `x`.
Args:
x: The input data.
step: The current step.
Returns:
The denoised version of the input data `x`.
"""
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
current_timestep = self.timesteps[step]
next_timestep = self.timesteps[step - 1]
@ -106,12 +140,21 @@ class DPMSolver(Solver):
return denoised_x
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
"""
Represents one step of the backward diffusion process that iteratively denoises the input data `x`.
"""Apply one step of the backward diffusion process.
This method works by estimating the denoised version of `x` and applying either a first-order or second-order
backward Euler update, which is a numerical method commonly used to solve ordinary differential equations
(ODEs).
Note:
This method works by estimating the denoised version of `x` and applying either a first-order or second-order
backward Euler update, which is a numerical method commonly used to solve ordinary differential equations
(ODEs).
Args:
x: The input data.
predicted_noise: The predicted noise.
step: The current step.
generator: The random number generator.
Returns:
The denoised version of the input data `x`.
"""
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"

View file

@ -6,6 +6,12 @@ from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule
class Euler(Solver):
"""Euler solver.
See [[arXiv:2206.00364] Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364)
for more details.
"""
def __init__(
self,
num_inference_steps: int,
@ -17,6 +23,18 @@ class Euler(Solver):
device: Device | str = "cpu",
dtype: Dtype = float32,
):
"""Initializes a new Euler solver.
Args:
num_inference_steps: The number of inference steps.
num_train_timesteps: The number of training timesteps.
initial_diffusion_rate: The initial diffusion rate.
final_diffusion_rate: The final diffusion rate.
noise_schedule: The noise schedule.
first_inference_step: The first inference step.
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
if noise_schedule != NoiseSchedule.QUADRATIC:
raise NotImplementedError
super().__init__(
@ -33,23 +51,40 @@ class Euler(Solver):
@property
def init_noise_sigma(self) -> Tensor:
"""The initial noise sigma."""
return self.sigmas.max()
def _generate_timesteps(self) -> Tensor:
# We need to use numpy here because:
# numpy.linspace(0,999,31)[15] is 499.49999999999994
# torch.linspace(0,999,31)[15] is 499.5
# ...and we want the same result as the original codebase.
"""Generate the timesteps used by the solver.
Note:
We need to use numpy here because:
- numpy.linspace(0,999,31)[15] is 499.49999999999994
- torch.linspace(0,999,31)[15] is 499.5
and we want the same result as the original codebase.
"""
timesteps = torch.tensor(np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps)).flip(0)
return timesteps
def _generate_sigmas(self) -> Tensor:
"""Generate the sigmas used by the solver."""
sigmas = self.noise_std / self.cumulative_scale_factors
sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy()))
sigmas = torch.cat([sigmas, tensor([0.0])])
return sigmas.to(device=self.device, dtype=self.dtype)
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
"""Scales the model input according to the current step.
Args:
x: The model input.
step: The current step.
Returns:
The scaled model input.
"""
sigma = self.sigmas[step]
return x / ((sigma**2 + 1) ** 0.5)