use default (1e-8) atol for basic DPM solver (used in SD1.5)

This commit is contained in:
Pierre Chapuis 2024-09-25 21:11:38 +02:00
parent 83b931296f
commit 5741c13f28
No known key found for this signature in database

View file

@ -73,7 +73,12 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool, sde_var
manual_seed(37)
refiners_outputs = [solver(x=sample, predicted_noise=predicted_noise, step=step) for step in range(n_steps)]
atol = 1e-4 if use_karras_sigmas else 1e-6
if use_karras_sigmas:
atol = 1e-4
elif sde_variance == 1.0:
atol = 1e-6
else:
atol = 1e-8
for step, (diffusers_output, refiners_output) in enumerate(zip(diffusers_outputs, refiners_outputs)):
assert torch.allclose(diffusers_output, refiners_output, rtol=0.01, atol=atol), f"outputs differ at step {step}"