diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py b/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py index e570898..3999a41 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py @@ -50,46 +50,43 @@ class DPMSolver(Scheduler): ).flip(0) def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor: - timestep, previous_timestep = ( - self.timesteps[step], - self.timesteps[step + 1 if step < len(self.timesteps) - 1 else 0], - ) - previous_ratio, current_ratio = ( - self.signal_to_noise_ratios[previous_timestep], - self.signal_to_noise_ratios[timestep], - ) + current_timestep = self.timesteps[step] + previous_timestep = self.timesteps[step + 1 if step < len(self.timesteps) - 1 else 0] + + previous_ratio = self.signal_to_noise_ratios[previous_timestep] + current_ratio = self.signal_to_noise_ratios[current_timestep] + previous_scale_factor = self.cumulative_scale_factors[previous_timestep] - previous_noise_std, current_noise_std = ( - self.noise_std[previous_timestep], - self.noise_std[timestep], - ) + + previous_noise_std = self.noise_std[previous_timestep] + current_noise_std = self.noise_std[current_timestep] + factor = exp(-(previous_ratio - current_ratio)) - 1.0 denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise return denoised_x def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor: - previous_timestep, current_timestep, next_timestep = ( - self.timesteps[step + 1] if step < len(self.timesteps) - 1 else tensor([0]), - self.timesteps[step], - self.timesteps[step - 1], - ) - current_data_estimation, next_data_estimation = self.estimated_data[-1], self.estimated_data[-2] - previous_ratio, current_ratio, next_ratio = ( - self.signal_to_noise_ratios[previous_timestep], - self.signal_to_noise_ratios[current_timestep], - self.signal_to_noise_ratios[next_timestep], - ) + previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0]) + current_timestep = self.timesteps[step] + next_timestep = self.timesteps[step - 1] + + current_data_estimation = self.estimated_data[-1] + next_data_estimation = self.estimated_data[-2] + + previous_ratio = self.signal_to_noise_ratios[previous_timestep] + current_ratio = self.signal_to_noise_ratios[current_timestep] + next_ratio = self.signal_to_noise_ratios[next_timestep] + previous_scale_factor = self.cumulative_scale_factors[previous_timestep] - previous_std, current_std = ( - self.noise_std[previous_timestep], - self.noise_std[current_timestep], - ) + previous_noise_std = self.noise_std[previous_timestep] + current_noise_std = self.noise_std[current_timestep] + estimation_delta = (current_data_estimation - next_data_estimation) / ( (current_ratio - next_ratio) / (previous_ratio - current_ratio) ) factor = exp(-(previous_ratio - current_ratio)) - 1.0 denoised_x = ( - (previous_std / current_std) * x + (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * current_data_estimation - 0.5 * (factor * previous_scale_factor) * estimation_delta )