implement _add_noise for dpm solver

This commit is contained in:
Laurent 2024-09-09 14:32:38 +00:00
parent a51d695523
commit 5ca6ca2f52
No known key found for this signature in database

View file

@ -174,7 +174,30 @@ class DPMSolver(Solver):
timestep = (1 - interpolation_weights) * low_indices + interpolation_weights * high_indices timestep = (1 - interpolation_weights) * low_indices + interpolation_weights * high_indices
timesteps.append(timestep) timesteps.append(timestep)
return torch.cat(timesteps).round() return torch.cat(timesteps).round().int()
def _add_noise(
self,
x: torch.Tensor,
noise: torch.Tensor,
step: int,
) -> torch.Tensor:
"""Add noise to the input tensor using the solver's parameters.
Args:
x: The input tensor to add noise to.
noise: The noise tensor to add to the input tensor.
step: The current step of the diffusion process.
Returns:
The input tensor with added noise.
"""
cumulative_scale_factors = self.cumulative_scale_factors[step]
noise_stds = self.noise_std[step]
# noisify the latents, arXiv:2006.11239 Eq. 4
noised_x = cumulative_scale_factors * x + noise_stds * noise
return noised_x
def _solver_tensors_from_sigmas(self, sigmas: torch.Tensor) -> SolverTensors: def _solver_tensors_from_sigmas(self, sigmas: torch.Tensor) -> SolverTensors:
"""Generate the tensors from the sigmas.""" """Generate the tensors from the sigmas."""