Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-14 00:58:13 +00:00)
fix precision of DPM solver in bfloat16

Commit: 74ce42f923
Parent: bdba91312b
@@ -77,7 +77,7 @@ class DPMSolver(Solver):
             first_inference_step=first_inference_step,
             params=params,
             device=device,
-            dtype=dtype,
+            dtype=torch.float64,  # compute constants precisely
         )
         self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2)
         self.last_step_first_order = last_step_first_order
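The hunk above makes the parent Solver build its schedule constants in float64 instead of the user-supplied dtype. As a rough illustration (not refiners code; the values below are the standard Stable Diffusion scaled-linear betas, used only as an example), running a cumulative-product noise schedule directly in bfloat16 loses most of the precision, because bfloat16 keeps only 8 mantissa bits and the smallest 1 - beta values round to exactly 1.0:

import torch

# Reference schedule in float64.
betas = torch.linspace(0.00085**0.5, 0.012**0.5, 1000, dtype=torch.float64) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
sigmas_f64 = ((1.0 - alphas_cumprod) / alphas_cumprod).sqrt()

# Same arithmetic in bfloat16: the smallest betas vanish against 1.0,
# so the cumulative product and the derived sigmas drift noticeably.
betas_bf16 = betas.to(torch.bfloat16)
alphas_cumprod_bf16 = torch.cumprod(1.0 - betas_bf16, dim=0)
sigmas_bf16 = ((1.0 - alphas_cumprod_bf16) / alphas_cumprod_bf16).sqrt()

rel_err = ((sigmas_bf16.to(torch.float64) - sigmas_f64).abs() / sigmas_f64).max()
print(f"max relative error: {rel_err.item():.2e}")

Doing the arithmetic in float64 and only casting the finished tensors limits the bfloat16 rounding to a single final conversion.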
@@ -89,6 +89,7 @@ class DPMSolver(Solver):
             self.sigmas
         )
         self.timesteps = self._timesteps_from_sigmas(sigmas)
+        self.to(dtype=dtype)

     def rebuild(
         self: "DPMSolver",
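The second hunk adds the matching downcast at the end of __init__: everything is computed in float64, then self.to(dtype=dtype) converts the finished tensors to whatever dtype the caller requested (e.g. bfloat16 for inference). A minimal sketch of that pattern, using a plain torch.nn.Module with registered buffers purely for illustration (the refiners Solver has its own base class):

import torch

class ScheduleHolder(torch.nn.Module):
    """Toy example of the 'build in float64, cast at the end' pattern."""

    def __init__(self, num_steps: int, dtype: torch.dtype = torch.float32) -> None:
        super().__init__()
        # Build every derived constant in float64 so the intermediate math stays precise.
        betas = torch.linspace(0.00085**0.5, 0.012**0.5, num_steps, dtype=torch.float64) ** 2
        alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
        self.register_buffer("sigmas", ((1.0 - alphas_cumprod) / alphas_cumprod).sqrt())
        # Only now convert the stored buffers to the dtype requested by the caller.
        self.to(dtype=dtype)

holder = ScheduleHolder(num_steps=1000, dtype=torch.bfloat16)
print(holder.sigmas.dtype)  # torch.bfloat16, but derived from float64 math

The buffers are computed once in float64, so the only bfloat16 rounding the caller sees is the final cast.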
@@ -131,12 +132,9 @@ class DPMSolver(Solver):
             case NoiseSchedule.KARRAS:
                 rho = 7
             case None:
-                if sigmas.dtype == torch.bfloat16:
-                    sigmas = sigmas.to(torch.float32)
                 return torch.tensor(
                     np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()),
                     device=self.device,
-                    dtype=self.dtype,
                 )

         linear_schedule = torch.linspace(0, 1, steps=self.num_inference_steps, device=self.device)
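The last hunk can drop the bfloat16 special case because sigmas now arrive in float64: np.interp goes through NumPy, which has no bfloat16 dtype, so converting a bfloat16 tensor fails, which is presumably why the old code upcast to float32 first. The explicit dtype=self.dtype is also gone, since the trailing self.to(dtype=dtype) handles the conversion. A small sketch of the NumPy limitation, with made-up sigma values:

import numpy as np
import torch

# Made-up sigmas spanning an SD-like range, purely for illustration.
sigmas = torch.linspace(14.6, 0.03, 1000).to(torch.bfloat16)

try:
    np.interp([0.0, 10.5, 999.0], np.arange(len(sigmas)), sigmas.cpu().numpy())
except TypeError as err:
    # NumPy has no bfloat16 dtype, so .numpy() on a bfloat16 tensor raises.
    print(err)

# The old workaround: upcast before crossing into NumPy.
interpolated = np.interp([0.0, 10.5, 999.0], np.arange(len(sigmas)), sigmas.to(torch.float32).cpu().numpy())
print(interpolated)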