add missing remove_noise method from the DPMSolver class

This commit is contained in:
Laurent 2024-10-01 08:12:51 +00:00
parent b6ceecf16a
commit 1e6a3f1d7c
No known key found for this signature in database

View file

@ -191,6 +191,16 @@ class DPMSolver(Solver):
noised_x = cumulative_scale_factors * x + noise_stds * noise noised_x = cumulative_scale_factors * x + noise_stds * noise
return noised_x return noised_x
def remove_noise(self, x: torch.Tensor, noise: torch.Tensor, step: int) -> torch.Tensor:
"""Remove noise from the input tensor using the current step of the diffusion process.
See [`Solver.remove_noise`][refiners.foundationals.latent_diffusion.solvers.solver.Solver.remove_noise] for more details.
"""
cumulative_scale_factors = self.cumulative_scale_factors[step]
noise_stds = self.noise_std[step]
denoised_x = (x - noise_stds * noise) / cumulative_scale_factors
return denoised_x
def _solver_tensors_from_sigmas(self, sigmas: torch.Tensor) -> SolverTensors: def _solver_tensors_from_sigmas(self, sigmas: torch.Tensor) -> SolverTensors:
"""Generate the tensors from the sigmas.""" """Generate the tensors from the sigmas."""
cumulative_scale_factors = 1 / torch.sqrt(sigmas**2 + 1) cumulative_scale_factors = 1 / torch.sqrt(sigmas**2 + 1)