check timesteps stay the same in bfloat16

This commit is contained in:
Pierre Chapuis 2024-09-25 23:17:58 +02:00
parent 5741c13f28
commit bdba91312b
No known key found for this signature in database

View file

@ -320,4 +320,10 @@ def test_dpm_bfloat16(test_device: Device):
if test_device.type == "cpu":
warn("not running on CPU, skipping")
pytest.skip()
DPMSolver(num_inference_steps=5, dtype=torch.bfloat16) # should not raise
n_steps = 5
manual_seed(0)
solver_f32 = DPMSolver(num_inference_steps=n_steps, dtype=torch.float32)
solver_bf16 = DPMSolver(num_inference_steps=n_steps, dtype=torch.bfloat16)
assert torch.equal(solver_bf16.timesteps, solver_f32.timesteps)