improve consistency of the dpm scheduler

This commit is contained in:
limiteinductive 2023-10-12 15:37:56 +02:00 committed by Benjamin Trom
parent 7a62049d54
commit 585c7ad55a

View file

@ -1,6 +1,6 @@
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
import numpy as np import numpy as np
from torch import Tensor, device as Device, tensor, exp from torch import Tensor, device as Device, tensor, exp, float32, dtype as Dtype
from collections import deque from collections import deque
@ -17,6 +17,7 @@ class DPMSolver(Scheduler):
initial_diffusion_rate: float = 8.5e-4, initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2, final_diffusion_rate: float = 1.2e-2,
device: Device | str = "cpu", device: Device | str = "cpu",
dtype: Dtype = float32,
): ):
super().__init__( super().__init__(
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,
@ -24,6 +25,7 @@ class DPMSolver(Scheduler):
initial_diffusion_rate=initial_diffusion_rate, initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate, final_diffusion_rate=final_diffusion_rate,
device=device, device=device,
dtype=dtype,
) )
self.estimated_data = deque([tensor([])] * 2, maxlen=2) self.estimated_data = deque([tensor([])] * 2, maxlen=2)
self.initial_steps = 0 self.initial_steps = 0
@ -52,8 +54,8 @@ class DPMSolver(Scheduler):
self.noise_std[previous_timestep], self.noise_std[previous_timestep],
self.noise_std[timestep], self.noise_std[timestep],
) )
exp_factor = exp(-(previous_ratio - current_ratio)) factor = exp(-(previous_ratio - current_ratio)) - 1.0
denoised_x = (previous_noise_std / current_noise_std) * x - (previous_scale_factor * (exp_factor - 1.0)) * noise denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise
return denoised_x return denoised_x
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor: def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
@ -76,13 +78,13 @@ class DPMSolver(Scheduler):
estimation_delta = (current_data_estimation - next_data_estimation) / ( estimation_delta = (current_data_estimation - next_data_estimation) / (
(current_ratio - next_ratio) / (previous_ratio - current_ratio) (current_ratio - next_ratio) / (previous_ratio - current_ratio)
) )
exp_neg_factor = exp(-(previous_ratio - current_ratio)) factor = exp(-(previous_ratio - current_ratio)) - 1.0
x_t = ( denoised_x = (
(previous_std / current_std) * x (previous_std / current_std) * x
- (previous_scale_factor * (exp_neg_factor - 1.0)) * current_data_estimation - (factor * previous_scale_factor) * current_data_estimation
- 0.5 * (previous_scale_factor * (exp_neg_factor - 1.0)) * estimation_delta - 0.5 * (factor * previous_scale_factor) * estimation_delta
) )
return x_t return denoised_x
def __call__( def __call__(
self, self,