diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py b/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py
index a7b2f60..438a964 100644
--- a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py
+++ b/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py
@@ -1,6 +1,6 @@
 from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
 import numpy as np
-from torch import Tensor, device as Device, tensor, exp
+from torch import Tensor, device as Device, tensor, exp, float32, dtype as Dtype
 from collections import deque
 
 
@@ -17,6 +17,7 @@ class DPMSolver(Scheduler):
         initial_diffusion_rate: float = 8.5e-4,
         final_diffusion_rate: float = 1.2e-2,
         device: Device | str = "cpu",
+        dtype: Dtype = float32,
     ):
         super().__init__(
             num_inference_steps=num_inference_steps,
@@ -24,6 +25,7 @@ class DPMSolver(Scheduler):
             initial_diffusion_rate=initial_diffusion_rate,
             final_diffusion_rate=final_diffusion_rate,
             device=device,
+            dtype=dtype,
         )
         self.estimated_data = deque([tensor([])] * 2, maxlen=2)
         self.initial_steps = 0
@@ -52,8 +54,8 @@ class DPMSolver(Scheduler):
             self.noise_std[previous_timestep],
             self.noise_std[timestep],
         )
-        exp_factor = exp(-(previous_ratio - current_ratio))
-        denoised_x = (previous_noise_std / current_noise_std) * x - (previous_scale_factor * (exp_factor - 1.0)) * noise
+        factor = exp(-(previous_ratio - current_ratio)) - 1.0
+        denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise
         return denoised_x
 
     def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
@@ -76,13 +78,13 @@ class DPMSolver(Scheduler):
         estimation_delta = (current_data_estimation - next_data_estimation) / (
             (current_ratio - next_ratio) / (previous_ratio - current_ratio)
         )
-        exp_neg_factor = exp(-(previous_ratio - current_ratio))
-        x_t = (
+        factor = exp(-(previous_ratio - current_ratio)) - 1.0
+        denoised_x = (
             (previous_std / current_std) * x
-            - (previous_scale_factor * (exp_neg_factor - 1.0)) * current_data_estimation
-            - 0.5 * (previous_scale_factor * (exp_neg_factor - 1.0)) * estimation_delta
+            - (factor * previous_scale_factor) * current_data_estimation
+            - 0.5 * (factor * previous_scale_factor) * estimation_delta
         )
-        return x_t
+        return denoised_x
 
     def __call__(
         self,
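
For reviewers, a minimal sketch of how the new dtype argument could be exercised, plus a check that the factor refactor is algebraically neutral. This is a usage sketch under stated assumptions, not part of the patch: the num_inference_steps value, the "cuda" device, and the float16 choice are illustrative, and the constructor arguments before initial_diffusion_rate are only partially visible in the hunks above.

# Sketch only (assumptions flagged): the DPMSolver import path is taken from the
# diff header; num_inference_steps is assumed to remain the leading constructor
# argument, and "cuda"/float16 are illustrative values.
from torch import allclose, exp, float16, tensor

from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver

# Default behaviour is unchanged: dtype falls back to float32.
solver_default = DPMSolver(num_inference_steps=30)

# The new parameter is forwarded to Scheduler's constructor, presumably so the
# base class can allocate its internal tensors in the requested precision.
solver_half = DPMSolver(num_inference_steps=30, device="cuda", dtype=float16)

# The update refactor only hoists the "- 1.0" into `factor` and renames
# variables; the two expressions compared below are algebraically identical.
previous_ratio, current_ratio = tensor(0.8), tensor(0.5)
previous_scale_factor = tensor(0.9)
exp_factor = exp(-(previous_ratio - current_ratio))        # old form
factor = exp(-(previous_ratio - current_ratio)) - 1.0      # new form
assert allclose(previous_scale_factor * (exp_factor - 1.0), factor * previous_scale_factor)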