From 7fecf0298dacb392d73091147df3ad3fd37696c6 Mon Sep 17 00:00:00 2001 From: Laurent Date: Mon, 9 Sep 2024 14:32:38 +0000 Subject: [PATCH] implement _add_noise for dpm solver --- .../latent_diffusion/solvers/dpm.py | 25 ++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/src/refiners/foundationals/latent_diffusion/solvers/dpm.py b/src/refiners/foundationals/latent_diffusion/solvers/dpm.py index 08d56c4..71f2175 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/dpm.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/dpm.py @@ -174,7 +174,30 @@ class DPMSolver(Solver): timestep = (1 - interpolation_weights) * low_indices + interpolation_weights * high_indices timesteps.append(timestep) - return torch.cat(timesteps).round() + return torch.cat(timesteps).round().int() + + def _add_noise( + self, + x: torch.Tensor, + noise: torch.Tensor, + step: int, + ) -> torch.Tensor: + """Add noise to the input tensor using the solver's parameters. + + Args: + x: The input tensor to add noise to. + noise: The noise tensor to add to the input tensor. + step: The current step of the diffusion process. + + Returns: + The input tensor with added noise. + """ + cumulative_scale_factors = self.cumulative_scale_factors[step] + noise_stds = self.noise_std[step] + + # noisify the latents, arXiv:2006.11239 Eq. 4 + noised_x = cumulative_scale_factors * x + noise_stds * noise + return noised_x def _solver_tensors_from_sigmas(self, sigmas: torch.Tensor) -> SolverTensors: """Generate the tensors from the sigmas."""