mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
add missing remove_noise method from the DPMSolver class
This commit is contained in:
parent
b6ceecf16a
commit
1e6a3f1d7c
|
@ -191,6 +191,16 @@ class DPMSolver(Solver):
|
||||||
noised_x = cumulative_scale_factors * x + noise_stds * noise
|
noised_x = cumulative_scale_factors * x + noise_stds * noise
|
||||||
return noised_x
|
return noised_x
|
||||||
|
|
||||||
|
def remove_noise(self, x: torch.Tensor, noise: torch.Tensor, step: int) -> torch.Tensor:
|
||||||
|
"""Remove noise from the input tensor using the current step of the diffusion process.
|
||||||
|
|
||||||
|
See [`Solver.remove_noise`][refiners.foundationals.latent_diffusion.solvers.solver.Solver.remove_noise] for more details.
|
||||||
|
"""
|
||||||
|
cumulative_scale_factors = self.cumulative_scale_factors[step]
|
||||||
|
noise_stds = self.noise_std[step]
|
||||||
|
denoised_x = (x - noise_stds * noise) / cumulative_scale_factors
|
||||||
|
return denoised_x
|
||||||
|
|
||||||
def _solver_tensors_from_sigmas(self, sigmas: torch.Tensor) -> SolverTensors:
|
def _solver_tensors_from_sigmas(self, sigmas: torch.Tensor) -> SolverTensors:
|
||||||
"""Generate the tensors from the sigmas."""
|
"""Generate the tensors from the sigmas."""
|
||||||
cumulative_scale_factors = 1 / torch.sqrt(sigmas**2 + 1)
|
cumulative_scale_factors = 1 / torch.sqrt(sigmas**2 + 1)
|
||||||
|
|
Loading…
Reference in a new issue