fix precision of DPM solver in bfloat16

This commit is contained in:
Pierre Chapuis 2024-09-25 21:46:58 +02:00
parent bdba91312b
commit 74ce42f923
No known key found for this signature in database

View file

@ -77,7 +77,7 @@ class DPMSolver(Solver):
first_inference_step=first_inference_step, first_inference_step=first_inference_step,
params=params, params=params,
device=device, device=device,
dtype=dtype, dtype=torch.float64, # compute constants precisely
) )
self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2) self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2)
self.last_step_first_order = last_step_first_order self.last_step_first_order = last_step_first_order
@ -89,6 +89,7 @@ class DPMSolver(Solver):
self.sigmas self.sigmas
) )
self.timesteps = self._timesteps_from_sigmas(sigmas) self.timesteps = self._timesteps_from_sigmas(sigmas)
self.to(dtype=dtype)
def rebuild( def rebuild(
self: "DPMSolver", self: "DPMSolver",
@ -131,12 +132,9 @@ class DPMSolver(Solver):
case NoiseSchedule.KARRAS: case NoiseSchedule.KARRAS:
rho = 7 rho = 7
case None: case None:
if sigmas.dtype == torch.bfloat16:
sigmas = sigmas.to(torch.float32)
return torch.tensor( return torch.tensor(
np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()), np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()),
device=self.device, device=self.device,
dtype=self.dtype,
) )
linear_schedule = torch.linspace(0, 1, steps=self.num_inference_steps, device=self.device) linear_schedule = torch.linspace(0, 1, steps=self.num_inference_steps, device=self.device)