implement _add_noise for dpm solver
Some checks failed
Spell checker / Spell check (push) Has been cancelled
CI / lint_and_typecheck (push) Has been cancelled

This commit is contained in:
Laurent 2024-09-09 14:32:38 +00:00 committed by Laureηt
parent 6d58492097
commit 7fecf0298d

View file

@ -174,7 +174,30 @@ class DPMSolver(Solver):
timestep = (1 - interpolation_weights) * low_indices + interpolation_weights * high_indices
timesteps.append(timestep)
return torch.cat(timesteps).round()
return torch.cat(timesteps).round().int()
def _add_noise(
self,
x: torch.Tensor,
noise: torch.Tensor,
step: int,
) -> torch.Tensor:
"""Add noise to the input tensor using the solver's parameters.
Args:
x: The input tensor to add noise to.
noise: The noise tensor to add to the input tensor.
step: The current step of the diffusion process.
Returns:
The input tensor with added noise.
"""
cumulative_scale_factors = self.cumulative_scale_factors[step]
noise_stds = self.noise_std[step]
# noisify the latents, arXiv:2006.11239 Eq. 4
noised_x = cumulative_scale_factors * x + noise_stds * noise
return noised_x
def _solver_tensors_from_sigmas(self, sigmas: torch.Tensor) -> SolverTensors:
"""Generate the tensors from the sigmas."""