fix bug in dpm_solver_first_order_update

This commit is contained in:
Pierre Chapuis 2024-01-18 14:27:30 +01:00
parent 59db1f0bd5
commit 999e429697

View file

@ -51,7 +51,7 @@ class DPMSolver(Scheduler):
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
current_timestep = self.timesteps[step]
previous_timestep = self.timesteps[step + 1 if step < len(self.timesteps) - 1 else 0]
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
previous_ratio = self.signal_to_noise_ratios[previous_timestep]
current_ratio = self.signal_to_noise_ratios[current_timestep]