remove stochastic Euler

It was untested and likely doesn't work.
We will re-introduce it later if needed.
This commit is contained in:
Pierre Chapuis 2024-01-30 19:28:35 +01:00 committed by Cédric Deltheil
parent 8a2b955bd0
commit 12aa0b23f6

View file

@ -59,30 +59,6 @@ class Euler(Solver):
predicted_noise: Tensor, predicted_noise: Tensor,
step: int, step: int,
generator: Generator | None = None, generator: Generator | None = None,
s_churn: float = 0.0,
s_tmin: float = 0.0,
s_tmax: float = float("inf"),
s_noise: float = 1.0,
) -> Tensor: ) -> Tensor:
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}" assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
return x + predicted_noise * (self.sigmas[step + 1] - self.sigmas[step])
sigma = self.sigmas[step]
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0
noise = torch.randn(
predicted_noise.shape, generator=generator, device=predicted_noise.device, dtype=predicted_noise.dtype
)
eps = noise * s_noise
sigma_hat = sigma * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
predicted_x = x - sigma_hat * predicted_noise
# 1st order Euler
derivative = (x - predicted_x) / sigma_hat
dt = self.sigmas[step + 1] - sigma_hat
denoised_x = x + derivative * dt
return denoised_x