mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-12 16:18:22 +00:00
fix precision of DPM solver in bfloat16
This commit is contained in:
parent
bdba91312b
commit
74ce42f923
|
@ -77,7 +77,7 @@ class DPMSolver(Solver):
|
|||
first_inference_step=first_inference_step,
|
||||
params=params,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
dtype=torch.float64, # compute constants precisely
|
||||
)
|
||||
self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2)
|
||||
self.last_step_first_order = last_step_first_order
|
||||
|
@ -89,6 +89,7 @@ class DPMSolver(Solver):
|
|||
self.sigmas
|
||||
)
|
||||
self.timesteps = self._timesteps_from_sigmas(sigmas)
|
||||
self.to(dtype=dtype)
|
||||
|
||||
def rebuild(
|
||||
self: "DPMSolver",
|
||||
|
@ -131,12 +132,9 @@ class DPMSolver(Solver):
|
|||
case NoiseSchedule.KARRAS:
|
||||
rho = 7
|
||||
case None:
|
||||
if sigmas.dtype == torch.bfloat16:
|
||||
sigmas = sigmas.to(torch.float32)
|
||||
return torch.tensor(
|
||||
np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()),
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
linear_schedule = torch.linspace(0, 1, steps=self.num_inference_steps, device=self.device)
|
||||
|
|
Loading…
Reference in a new issue