diff --git a/src/refiners/foundationals/latent_diffusion/solvers/euler.py b/src/refiners/foundationals/latent_diffusion/solvers/euler.py index 17711cc..6643081 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/euler.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/euler.py @@ -59,30 +59,6 @@ class Euler(Solver): predicted_noise: Tensor, step: int, generator: Generator | None = None, - s_churn: float = 0.0, - s_tmin: float = 0.0, - s_tmax: float = float("inf"), - s_noise: float = 1.0, ) -> Tensor: assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}" - - sigma = self.sigmas[step] - - gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0 - - noise = torch.randn( - predicted_noise.shape, generator=generator, device=predicted_noise.device, dtype=predicted_noise.dtype - ) - eps = noise * s_noise - sigma_hat = sigma * (gamma + 1) - if gamma > 0: - x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5 - - predicted_x = x - sigma_hat * predicted_noise - - # 1st order Euler - derivative = (x - predicted_x) / sigma_hat - dt = self.sigmas[step + 1] - sigma_hat - denoised_x = x + derivative * dt - - return denoised_x + return x + predicted_noise * (self.sigmas[step + 1] - self.sigmas[step])