From bdba91312b9ce19b36f4b3c1a8c837f436ebe4aa Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Wed, 25 Sep 2024 23:17:58 +0200 Subject: [PATCH] check timesteps stay the same in bfloat16 --- tests/foundationals/latent_diffusion/test_solvers.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/foundationals/latent_diffusion/test_solvers.py b/tests/foundationals/latent_diffusion/test_solvers.py index 577da21..805ec1f 100644 --- a/tests/foundationals/latent_diffusion/test_solvers.py +++ b/tests/foundationals/latent_diffusion/test_solvers.py @@ -320,4 +320,10 @@ def test_dpm_bfloat16(test_device: Device): if test_device.type == "cpu": warn("not running on CPU, skipping") pytest.skip() - DPMSolver(num_inference_steps=5, dtype=torch.bfloat16) # should not raise + + n_steps = 5 + manual_seed(0) + + solver_f32 = DPMSolver(num_inference_steps=n_steps, dtype=torch.float32) + solver_bf16 = DPMSolver(num_inference_steps=n_steps, dtype=torch.bfloat16) + assert torch.equal(solver_bf16.timesteps, solver_f32.timesteps)