mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-14 17:18:14 +00:00
check timesteps stay the same in bfloat16
This commit is contained in:
parent
5741c13f28
commit
bdba91312b
|
@ -320,4 +320,10 @@ def test_dpm_bfloat16(test_device: Device):
|
||||||
if test_device.type == "cpu":
|
if test_device.type == "cpu":
|
||||||
warn("not running on CPU, skipping")
|
warn("not running on CPU, skipping")
|
||||||
pytest.skip()
|
pytest.skip()
|
||||||
DPMSolver(num_inference_steps=5, dtype=torch.bfloat16) # should not raise
|
|
||||||
|
n_steps = 5
|
||||||
|
manual_seed(0)
|
||||||
|
|
||||||
|
solver_f32 = DPMSolver(num_inference_steps=n_steps, dtype=torch.float32)
|
||||||
|
solver_bf16 = DPMSolver(num_inference_steps=n_steps, dtype=torch.bfloat16)
|
||||||
|
assert torch.equal(solver_bf16.timesteps, solver_f32.timesteps)
|
||||||
|
|
Loading…
Reference in a new issue