- avoid useless multiple assignments
- use coherent variable names
This commit is contained in:
Pierre Chapuis 2024-01-18 14:26:39 +01:00
parent aaddead17d
commit 59db1f0bd5

View file

@ -50,46 +50,43 @@ class DPMSolver(Scheduler):
).flip(0) ).flip(0)
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor: def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
timestep, previous_timestep = ( current_timestep = self.timesteps[step]
self.timesteps[step], previous_timestep = self.timesteps[step + 1 if step < len(self.timesteps) - 1 else 0]
self.timesteps[step + 1 if step < len(self.timesteps) - 1 else 0],
) previous_ratio = self.signal_to_noise_ratios[previous_timestep]
previous_ratio, current_ratio = ( current_ratio = self.signal_to_noise_ratios[current_timestep]
self.signal_to_noise_ratios[previous_timestep],
self.signal_to_noise_ratios[timestep],
)
previous_scale_factor = self.cumulative_scale_factors[previous_timestep] previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
previous_noise_std, current_noise_std = (
self.noise_std[previous_timestep], previous_noise_std = self.noise_std[previous_timestep]
self.noise_std[timestep], current_noise_std = self.noise_std[current_timestep]
)
factor = exp(-(previous_ratio - current_ratio)) - 1.0 factor = exp(-(previous_ratio - current_ratio)) - 1.0
denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise
return denoised_x return denoised_x
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor: def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
previous_timestep, current_timestep, next_timestep = ( previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
self.timesteps[step + 1] if step < len(self.timesteps) - 1 else tensor([0]), current_timestep = self.timesteps[step]
self.timesteps[step], next_timestep = self.timesteps[step - 1]
self.timesteps[step - 1],
) current_data_estimation = self.estimated_data[-1]
current_data_estimation, next_data_estimation = self.estimated_data[-1], self.estimated_data[-2] next_data_estimation = self.estimated_data[-2]
previous_ratio, current_ratio, next_ratio = (
self.signal_to_noise_ratios[previous_timestep], previous_ratio = self.signal_to_noise_ratios[previous_timestep]
self.signal_to_noise_ratios[current_timestep], current_ratio = self.signal_to_noise_ratios[current_timestep]
self.signal_to_noise_ratios[next_timestep], next_ratio = self.signal_to_noise_ratios[next_timestep]
)
previous_scale_factor = self.cumulative_scale_factors[previous_timestep] previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
previous_std, current_std = ( previous_noise_std = self.noise_std[previous_timestep]
self.noise_std[previous_timestep], current_noise_std = self.noise_std[current_timestep]
self.noise_std[current_timestep],
)
estimation_delta = (current_data_estimation - next_data_estimation) / ( estimation_delta = (current_data_estimation - next_data_estimation) / (
(current_ratio - next_ratio) / (previous_ratio - current_ratio) (current_ratio - next_ratio) / (previous_ratio - current_ratio)
) )
factor = exp(-(previous_ratio - current_ratio)) - 1.0 factor = exp(-(previous_ratio - current_ratio)) - 1.0
denoised_x = ( denoised_x = (
(previous_std / current_std) * x (previous_noise_std / current_noise_std) * x
- (factor * previous_scale_factor) * current_data_estimation - (factor * previous_scale_factor) * current_data_estimation
- 0.5 * (factor * previous_scale_factor) * estimation_delta - 0.5 * (factor * previous_scale_factor) * estimation_delta
) )