Improve consistency of the DPM scheduler

This commit is contained in:
limiteinductive 2023-10-12 15:37:56 +02:00 committed by Benjamin Trom
parent 7a62049d54
commit 585c7ad55a

View file

@ -1,6 +1,6 @@
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
import numpy as np
from torch import Tensor, device as Device, tensor, exp
from torch import Tensor, device as Device, tensor, exp, float32, dtype as Dtype
from collections import deque
@ -17,6 +17,7 @@ class DPMSolver(Scheduler):
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
device: Device | str = "cpu",
dtype: Dtype = float32,
):
super().__init__(
num_inference_steps=num_inference_steps,
@ -24,6 +25,7 @@ class DPMSolver(Scheduler):
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
device=device,
dtype=dtype,
)
self.estimated_data = deque([tensor([])] * 2, maxlen=2)
self.initial_steps = 0
@ -52,8 +54,8 @@ class DPMSolver(Scheduler):
self.noise_std[previous_timestep],
self.noise_std[timestep],
)
exp_factor = exp(-(previous_ratio - current_ratio))
denoised_x = (previous_noise_std / current_noise_std) * x - (previous_scale_factor * (exp_factor - 1.0)) * noise
factor = exp(-(previous_ratio - current_ratio)) - 1.0
denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise
return denoised_x
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
@ -76,13 +78,13 @@ class DPMSolver(Scheduler):
estimation_delta = (current_data_estimation - next_data_estimation) / (
(current_ratio - next_ratio) / (previous_ratio - current_ratio)
)
exp_neg_factor = exp(-(previous_ratio - current_ratio))
x_t = (
factor = exp(-(previous_ratio - current_ratio)) - 1.0
denoised_x = (
(previous_std / current_std) * x
- (previous_scale_factor * (exp_neg_factor - 1.0)) * current_data_estimation
- 0.5 * (previous_scale_factor * (exp_neg_factor - 1.0)) * estimation_delta
- (factor * previous_scale_factor) * current_data_estimation
- 0.5 * (factor * previous_scale_factor) * estimation_delta
)
return x_t
return denoised_x
def __call__(
self,