diff --git a/tests/foundationals/latent_diffusion/test_solvers.py b/tests/foundationals/latent_diffusion/test_solvers.py index 0066174..577da21 100644 --- a/tests/foundationals/latent_diffusion/test_solvers.py +++ b/tests/foundationals/latent_diffusion/test_solvers.py @@ -73,7 +73,12 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool, sde_var manual_seed(37) refiners_outputs = [solver(x=sample, predicted_noise=predicted_noise, step=step) for step in range(n_steps)] - atol = 1e-4 if use_karras_sigmas else 1e-6 + if use_karras_sigmas: + atol = 1e-4 + elif sde_variance == 1.0: + atol = 1e-6 + else: + atol = 1e-8 for step, (diffusers_output, refiners_output) in enumerate(zip(diffusers_outputs, refiners_outputs)): assert torch.allclose(diffusers_output, refiners_output, rtol=0.01, atol=atol), f"outputs differ at step {step}"