From 74ce42f9237728de9cc4bcb8c90c6da808dc5e0b Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Wed, 25 Sep 2024 21:46:58 +0200 Subject: [PATCH] fix precision of DPM solver in bfloat16 --- src/refiners/foundationals/latent_diffusion/solvers/dpm.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/refiners/foundationals/latent_diffusion/solvers/dpm.py b/src/refiners/foundationals/latent_diffusion/solvers/dpm.py index a9d1c16..0ff2ed0 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/dpm.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/dpm.py @@ -77,7 +77,7 @@ class DPMSolver(Solver): first_inference_step=first_inference_step, params=params, device=device, - dtype=dtype, + dtype=torch.float64, # compute constants precisely ) self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2) self.last_step_first_order = last_step_first_order @@ -89,6 +89,7 @@ class DPMSolver(Solver): self.sigmas ) self.timesteps = self._timesteps_from_sigmas(sigmas) + self.to(dtype=dtype) def rebuild( self: "DPMSolver", @@ -131,12 +132,9 @@ class DPMSolver(Solver): case NoiseSchedule.KARRAS: rho = 7 case None: - if sigmas.dtype == torch.bfloat16: - sigmas = sigmas.to(torch.float32) return torch.tensor( np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()), device=self.device, - dtype=self.dtype, ) linear_schedule = torch.linspace(0, 1, steps=self.num_inference_steps, device=self.device)