diff --git a/src/refiners/foundationals/latent_diffusion/solvers/dpm.py b/src/refiners/foundationals/latent_diffusion/solvers/dpm.py index 2ec2de4..b2296dc 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/dpm.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/dpm.py @@ -191,6 +191,16 @@ class DPMSolver(Solver): noised_x = cumulative_scale_factors * x + noise_stds * noise return noised_x + def remove_noise(self, x: torch.Tensor, noise: torch.Tensor, step: int) -> torch.Tensor: + """Remove noise from the input tensor using the current step of the diffusion process. + + See [`Solver.remove_noise`][refiners.foundationals.latent_diffusion.solvers.solver.Solver.remove_noise] for more details. + """ + cumulative_scale_factors = self.cumulative_scale_factors[step] + noise_stds = self.noise_std[step] + denoised_x = (x - noise_stds * noise) / cumulative_scale_factors + return denoised_x + def _solver_tensors_from_sigmas(self, sigmas: torch.Tensor) -> SolverTensors: """Generate the tensors from the sigmas.""" cumulative_scale_factors = 1 / torch.sqrt(sigmas**2 + 1)