add missing remove_noise method from the DPMSolver class

This commit is contained in:
Laurent 2024-10-01 08:12:51 +00:00 committed by Laureηt
parent 8bd405661d
commit 18ef5684c0

View file

@ -191,6 +191,16 @@ class DPMSolver(Solver):
noised_x = cumulative_scale_factors * x + noise_stds * noise
return noised_x
def remove_noise(self, x: torch.Tensor, noise: torch.Tensor, step: int) -> torch.Tensor:
"""Remove noise from the input tensor using the current step of the diffusion process.
See [`Solver.remove_noise`][refiners.foundationals.latent_diffusion.solvers.solver.Solver.remove_noise] for more details.
"""
cumulative_scale_factors = self.cumulative_scale_factors[step]
noise_stds = self.noise_std[step]
denoised_x = (x - noise_stds * noise) / cumulative_scale_factors
return denoised_x
def _solver_tensors_from_sigmas(self, sigmas: torch.Tensor) -> SolverTensors:
"""Generate the tensors from the sigmas."""
cumulative_scale_factors = 1 / torch.sqrt(sigmas**2 + 1)