(doc/fluxion/ld) add DDPM, DDIM, DPM++ and Euleur docstrings

This commit is contained in:
Laurent 2024-02-02 12:59:30 +00:00 committed by Laureηt
parent 24c11745cd
commit ff17991261
4 changed files with 132 additions and 26 deletions

View file

@ -4,6 +4,11 @@ from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule
class DDIM(Solver): class DDIM(Solver):
"""Denoising Diffusion Implicit Model (DDIM) solver.
See [[arXiv:2010.02502] Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) for more details.
"""
def __init__( def __init__(
self, self,
num_inference_steps: int, num_inference_steps: int,
@ -15,6 +20,18 @@ class DDIM(Solver):
device: Device | str = "cpu", device: Device | str = "cpu",
dtype: Dtype = float32, dtype: Dtype = float32,
) -> None: ) -> None:
"""Initializes a new DDIM solver.
Args:
num_inference_steps: The number of inference steps.
num_train_timesteps: The number of training timesteps.
initial_diffusion_rate: The initial diffusion rate.
final_diffusion_rate: The final diffusion rate.
noise_schedule: The noise schedule.
first_inference_step: The first inference step.
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
super().__init__( super().__init__(
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps, num_train_timesteps=num_train_timesteps,

View file

@ -1,13 +1,16 @@
from torch import Generator, Tensor, arange, device as Device, dtype as DType from torch import Generator, Tensor, arange, device as Device
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver from refiners.foundationals.latent_diffusion.solvers.solver import Solver
class DDPM(Solver): class DDPM(Solver):
""" """Denoising Diffusion Probabilistic Model (DDPM) solver.
Denoising Diffusion Probabilistic Model
Only used for training Latent Diffusion models. Cannot be called. Warning:
Only used for training Latent Diffusion models.
Cannot be called.
See [[arXiv:2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) for more details.
""" """
def __init__( def __init__(
@ -16,11 +19,19 @@ class DDPM(Solver):
num_train_timesteps: int = 1_000, num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4, initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2, final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule | None = None, # ignored
first_inference_step: int = 0, first_inference_step: int = 0,
device: Device | str = "cpu", device: Device | str = "cpu",
dtype: DType | None = None, # ignored
) -> None: ) -> None:
"""Initializes a new DDPM solver.
Args:
num_inference_steps: The number of inference steps.
num_train_timesteps: The number of training timesteps.
initial_diffusion_rate: The initial diffusion rate.
final_diffusion_rate: The final diffusion rate.
first_inference_step: The first inference step.
device: The PyTorch device to use.
"""
super().__init__( super().__init__(
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps, num_train_timesteps=num_train_timesteps,

View file

@ -7,9 +7,12 @@ from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule
class DPMSolver(Solver): class DPMSolver(Solver):
""" """Diffusion probabilistic models (DPMs) solver.
Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
See [[arXiv:2211.01095] DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models](https://arxiv.org/abs/2211.01095)
for more details.
Note:
Regarding last_step_first_order: DPM-Solver++ is known to introduce artifacts Regarding last_step_first_order: DPM-Solver++ is known to introduce artifacts
when used with SDXL and few steps. This parameter is a way to mitigate that when used with SDXL and few steps. This parameter is a way to mitigate that
effect by using a first-order (Euler) update instead of a second-order update effect by using a first-order (Euler) update instead of a second-order update
@ -42,10 +45,16 @@ class DPMSolver(Solver):
self.last_step_first_order = last_step_first_order self.last_step_first_order = last_step_first_order
def _generate_timesteps(self) -> Tensor: def _generate_timesteps(self) -> Tensor:
# We need to use numpy here because: """Generate the timesteps used by the solver.
# numpy.linspace(0,999,31)[15] is 499.49999999999994
# torch.linspace(0,999,31)[15] is 499.5 Note:
# ...and we want the same result as the original codebase. We need to use numpy here because:
- numpy.linspace(0,999,31)[15] is 499.49999999999994
- torch.linspace(0,999,31)[15] is 499.5
and we want the same result as the original codebase.
"""
return tensor( return tensor(
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:], np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:],
).flip(0) ).flip(0)
@ -55,6 +64,12 @@ class DPMSolver(Solver):
num_inference_steps: int | None, num_inference_steps: int | None,
first_inference_step: int | None = None, first_inference_step: int | None = None,
) -> "DPMSolver": ) -> "DPMSolver":
"""Rebuilds the solver with new parameters.
Args:
num_inference_steps: The number of inference steps.
first_inference_step: The first inference step.
"""
r = super().rebuild( r = super().rebuild(
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,
first_inference_step=first_inference_step, first_inference_step=first_inference_step,
@ -63,6 +78,16 @@ class DPMSolver(Solver):
return r return r
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor: def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
"""Applies a first-order backward Euler update to the input data `x`.
Args:
x: The input data.
noise: The predicted noise.
step: The current step.
Returns:
The denoised version of the input data `x`.
"""
current_timestep = self.timesteps[step] current_timestep = self.timesteps[step]
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0]) previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
@ -79,6 +104,15 @@ class DPMSolver(Solver):
return denoised_x return denoised_x
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor: def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
"""Applies a second-order backward Euler update to the input data `x`.
Args:
x: The input data.
step: The current step.
Returns:
The denoised version of the input data `x`.
"""
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0]) previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
current_timestep = self.timesteps[step] current_timestep = self.timesteps[step]
next_timestep = self.timesteps[step - 1] next_timestep = self.timesteps[step - 1]
@ -106,12 +140,21 @@ class DPMSolver(Solver):
return denoised_x return denoised_x
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
""" """Apply one step of the backward diffusion process.
Represents one step of the backward diffusion process that iteratively denoises the input data `x`.
Note:
This method works by estimating the denoised version of `x` and applying either a first-order or second-order This method works by estimating the denoised version of `x` and applying either a first-order or second-order
backward Euler update, which is a numerical method commonly used to solve ordinary differential equations backward Euler update, which is a numerical method commonly used to solve ordinary differential equations
(ODEs). (ODEs).
Args:
x: The input data.
predicted_noise: The predicted noise.
step: The current step.
generator: The random number generator.
Returns:
The denoised version of the input data `x`.
""" """
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}" assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"

View file

@ -6,6 +6,12 @@ from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule
class Euler(Solver): class Euler(Solver):
"""Euler solver.
See [[arXiv:2206.00364] Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364)
for more details.
"""
def __init__( def __init__(
self, self,
num_inference_steps: int, num_inference_steps: int,
@ -17,6 +23,18 @@ class Euler(Solver):
device: Device | str = "cpu", device: Device | str = "cpu",
dtype: Dtype = float32, dtype: Dtype = float32,
): ):
"""Initializes a new Euler solver.
Args:
num_inference_steps: The number of inference steps.
num_train_timesteps: The number of training timesteps.
initial_diffusion_rate: The initial diffusion rate.
final_diffusion_rate: The final diffusion rate.
noise_schedule: The noise schedule.
first_inference_step: The first inference step.
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
if noise_schedule != NoiseSchedule.QUADRATIC: if noise_schedule != NoiseSchedule.QUADRATIC:
raise NotImplementedError raise NotImplementedError
super().__init__( super().__init__(
@ -33,23 +51,40 @@ class Euler(Solver):
@property @property
def init_noise_sigma(self) -> Tensor: def init_noise_sigma(self) -> Tensor:
"""The initial noise sigma."""
return self.sigmas.max() return self.sigmas.max()
def _generate_timesteps(self) -> Tensor: def _generate_timesteps(self) -> Tensor:
# We need to use numpy here because: """Generate the timesteps used by the solver.
# numpy.linspace(0,999,31)[15] is 499.49999999999994
# torch.linspace(0,999,31)[15] is 499.5 Note:
# ...and we want the same result as the original codebase. We need to use numpy here because:
- numpy.linspace(0,999,31)[15] is 499.49999999999994
- torch.linspace(0,999,31)[15] is 499.5
and we want the same result as the original codebase.
"""
timesteps = torch.tensor(np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps)).flip(0) timesteps = torch.tensor(np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps)).flip(0)
return timesteps return timesteps
def _generate_sigmas(self) -> Tensor: def _generate_sigmas(self) -> Tensor:
"""Generate the sigmas used by the solver."""
sigmas = self.noise_std / self.cumulative_scale_factors sigmas = self.noise_std / self.cumulative_scale_factors
sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy())) sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy()))
sigmas = torch.cat([sigmas, tensor([0.0])]) sigmas = torch.cat([sigmas, tensor([0.0])])
return sigmas.to(device=self.device, dtype=self.dtype) return sigmas.to(device=self.device, dtype=self.dtype)
def scale_model_input(self, x: Tensor, step: int) -> Tensor: def scale_model_input(self, x: Tensor, step: int) -> Tensor:
"""Scales the model input according to the current step.
Args:
x: The model input.
step: The current step.
Returns:
The scaled model input.
"""
sigma = self.sigmas[step] sigma = self.sigmas[step]
return x / ((sigma**2 + 1) ** 0.5) return x / ((sigma**2 + 1) ** 0.5)