From fc7b4dd62d223b14a9d1f4fb0ae8ad9897a773a3 Mon Sep 17 00:00:00 2001 From: Laurent Date: Fri, 2 Feb 2024 12:59:30 +0000 Subject: [PATCH] (doc/fluxion/ld) add `DDPM`, `DDIM`, `DPM++` and `Euleur` docstrings --- .../latent_diffusion/solvers/ddim.py | 17 +++++ .../latent_diffusion/solvers/ddpm.py | 25 +++++-- .../latent_diffusion/solvers/dpm.py | 73 +++++++++++++++---- .../latent_diffusion/solvers/euler.py | 43 ++++++++++- 4 files changed, 132 insertions(+), 26 deletions(-) diff --git a/src/refiners/foundationals/latent_diffusion/solvers/ddim.py b/src/refiners/foundationals/latent_diffusion/solvers/ddim.py index 8e02d45..141d088 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/ddim.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/ddim.py @@ -4,6 +4,11 @@ from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule class DDIM(Solver): + """Denoising Diffusion Implicit Model (DDIM) solver. + + See [[arXiv:2010.02502] Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) for more details. + """ + def __init__( self, num_inference_steps: int, @@ -15,6 +20,18 @@ class DDIM(Solver): device: Device | str = "cpu", dtype: Dtype = float32, ) -> None: + """Initializes a new DDIM solver. + + Args: + num_inference_steps: The number of inference steps. + num_train_timesteps: The number of training timesteps. + initial_diffusion_rate: The initial diffusion rate. + final_diffusion_rate: The final diffusion rate. + noise_schedule: The noise schedule. + first_inference_step: The first inference step. + device: The PyTorch device to use. + dtype: The PyTorch data type to use. + """ super().__init__( num_inference_steps=num_inference_steps, num_train_timesteps=num_train_timesteps, diff --git a/src/refiners/foundationals/latent_diffusion/solvers/ddpm.py b/src/refiners/foundationals/latent_diffusion/solvers/ddpm.py index 31cb52b..49ae9c2 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/ddpm.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/ddpm.py @@ -1,13 +1,16 @@ -from torch import Generator, Tensor, arange, device as Device, dtype as DType +from torch import Generator, Tensor, arange, device as Device -from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver +from refiners.foundationals.latent_diffusion.solvers.solver import Solver class DDPM(Solver): - """ - Denoising Diffusion Probabilistic Model + """Denoising Diffusion Probabilistic Model (DDPM) solver. - Only used for training Latent Diffusion models. Cannot be called. + Warning: + Only used for training Latent Diffusion models. + Cannot be called. + + See [[arXiv:2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) for more details. """ def __init__( @@ -16,11 +19,19 @@ class DDPM(Solver): num_train_timesteps: int = 1_000, initial_diffusion_rate: float = 8.5e-4, final_diffusion_rate: float = 1.2e-2, - noise_schedule: NoiseSchedule | None = None, # ignored first_inference_step: int = 0, device: Device | str = "cpu", - dtype: DType | None = None, # ignored ) -> None: + """Initializes a new DDPM solver. + + Args: + num_inference_steps: The number of inference steps. + num_train_timesteps: The number of training timesteps. + initial_diffusion_rate: The initial diffusion rate. + final_diffusion_rate: The final diffusion rate. + first_inference_step: The first inference step. + device: The PyTorch device to use. + """ super().__init__( num_inference_steps=num_inference_steps, num_train_timesteps=num_train_timesteps, diff --git a/src/refiners/foundationals/latent_diffusion/solvers/dpm.py b/src/refiners/foundationals/latent_diffusion/solvers/dpm.py index ca28913..64351a7 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/dpm.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/dpm.py @@ -7,13 +7,16 @@ from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule class DPMSolver(Solver): - """ - Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095 + """Diffusion probabilistic models (DPMs) solver. - Regarding last_step_first_order: DPM-Solver++ is known to introduce artifacts - when used with SDXL and few steps. This parameter is a way to mitigate that - effect by using a first-order (Euler) update instead of a second-order update - for the last step of the diffusion. + See [[arXiv:2211.01095] DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models](https://arxiv.org/abs/2211.01095) + for more details. + + Note: + Regarding last_step_first_order: DPM-Solver++ is known to introduce artifacts + when used with SDXL and few steps. This parameter is a way to mitigate that + effect by using a first-order (Euler) update instead of a second-order update + for the last step of the diffusion. """ def __init__( @@ -42,10 +45,16 @@ class DPMSolver(Solver): self.last_step_first_order = last_step_first_order def _generate_timesteps(self) -> Tensor: - # We need to use numpy here because: - # numpy.linspace(0,999,31)[15] is 499.49999999999994 - # torch.linspace(0,999,31)[15] is 499.5 - # ...and we want the same result as the original codebase. + """Generate the timesteps used by the solver. + + Note: + We need to use numpy here because: + + - numpy.linspace(0,999,31)[15] is 499.49999999999994 + - torch.linspace(0,999,31)[15] is 499.5 + + and we want the same result as the original codebase. + """ return tensor( np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:], ).flip(0) @@ -55,6 +64,12 @@ class DPMSolver(Solver): num_inference_steps: int | None, first_inference_step: int | None = None, ) -> "DPMSolver": + """Rebuilds the solver with new parameters. + + Args: + num_inference_steps: The number of inference steps. + first_inference_step: The first inference step. + """ r = super().rebuild( num_inference_steps=num_inference_steps, first_inference_step=first_inference_step, @@ -63,6 +78,16 @@ class DPMSolver(Solver): return r def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor: + """Applies a first-order backward Euler update to the input data `x`. + + Args: + x: The input data. + noise: The predicted noise. + step: The current step. + + Returns: + The denoised version of the input data `x`. + """ current_timestep = self.timesteps[step] previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0]) @@ -79,6 +104,15 @@ class DPMSolver(Solver): return denoised_x def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor: + """Applies a second-order backward Euler update to the input data `x`. + + Args: + x: The input data. + step: The current step. + + Returns: + The denoised version of the input data `x`. + """ previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0]) current_timestep = self.timesteps[step] next_timestep = self.timesteps[step - 1] @@ -106,12 +140,21 @@ class DPMSolver(Solver): return denoised_x def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: - """ - Represents one step of the backward diffusion process that iteratively denoises the input data `x`. + """Apply one step of the backward diffusion process. - This method works by estimating the denoised version of `x` and applying either a first-order or second-order - backward Euler update, which is a numerical method commonly used to solve ordinary differential equations - (ODEs). + Note: + This method works by estimating the denoised version of `x` and applying either a first-order or second-order + backward Euler update, which is a numerical method commonly used to solve ordinary differential equations + (ODEs). + + Args: + x: The input data. + predicted_noise: The predicted noise. + step: The current step. + generator: The random number generator. + + Returns: + The denoised version of the input data `x`. """ assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}" diff --git a/src/refiners/foundationals/latent_diffusion/solvers/euler.py b/src/refiners/foundationals/latent_diffusion/solvers/euler.py index 6643081..cef8897 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/euler.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/euler.py @@ -6,6 +6,12 @@ from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule class Euler(Solver): + """Euler solver. + + See [[arXiv:2206.00364] Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) + for more details. + """ + def __init__( self, num_inference_steps: int, @@ -17,6 +23,18 @@ class Euler(Solver): device: Device | str = "cpu", dtype: Dtype = float32, ): + """Initializes a new Euler solver. + + Args: + num_inference_steps: The number of inference steps. + num_train_timesteps: The number of training timesteps. + initial_diffusion_rate: The initial diffusion rate. + final_diffusion_rate: The final diffusion rate. + noise_schedule: The noise schedule. + first_inference_step: The first inference step. + device: The PyTorch device to use. + dtype: The PyTorch data type to use. + """ if noise_schedule != NoiseSchedule.QUADRATIC: raise NotImplementedError super().__init__( @@ -33,23 +51,40 @@ class Euler(Solver): @property def init_noise_sigma(self) -> Tensor: + """The initial noise sigma.""" return self.sigmas.max() def _generate_timesteps(self) -> Tensor: - # We need to use numpy here because: - # numpy.linspace(0,999,31)[15] is 499.49999999999994 - # torch.linspace(0,999,31)[15] is 499.5 - # ...and we want the same result as the original codebase. + """Generate the timesteps used by the solver. + + Note: + We need to use numpy here because: + + - numpy.linspace(0,999,31)[15] is 499.49999999999994 + - torch.linspace(0,999,31)[15] is 499.5 + + and we want the same result as the original codebase. + """ timesteps = torch.tensor(np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps)).flip(0) return timesteps def _generate_sigmas(self) -> Tensor: + """Generate the sigmas used by the solver.""" sigmas = self.noise_std / self.cumulative_scale_factors sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy())) sigmas = torch.cat([sigmas, tensor([0.0])]) return sigmas.to(device=self.device, dtype=self.dtype) def scale_model_input(self, x: Tensor, step: int) -> Tensor: + """Scales the model input according to the current step. + + Args: + x: The model input. + step: The current step. + + Returns: + The scaled model input. + """ sigma = self.sigmas[step] return x / ((sigma**2 + 1) ** 0.5)