mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
improve consistency of the dpm scheduler
This commit is contained in:
parent
7a62049d54
commit
585c7ad55a
|
@ -1,6 +1,6 @@
|
|||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||
import numpy as np
|
||||
from torch import Tensor, device as Device, tensor, exp
|
||||
from torch import Tensor, device as Device, tensor, exp, float32, dtype as Dtype
|
||||
from collections import deque
|
||||
|
||||
|
||||
|
@ -17,6 +17,7 @@ class DPMSolver(Scheduler):
|
|||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
device: Device | str = "cpu",
|
||||
dtype: Dtype = float32,
|
||||
):
|
||||
super().__init__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
|
@ -24,6 +25,7 @@ class DPMSolver(Scheduler):
|
|||
initial_diffusion_rate=initial_diffusion_rate,
|
||||
final_diffusion_rate=final_diffusion_rate,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.estimated_data = deque([tensor([])] * 2, maxlen=2)
|
||||
self.initial_steps = 0
|
||||
|
@ -52,8 +54,8 @@ class DPMSolver(Scheduler):
|
|||
self.noise_std[previous_timestep],
|
||||
self.noise_std[timestep],
|
||||
)
|
||||
exp_factor = exp(-(previous_ratio - current_ratio))
|
||||
denoised_x = (previous_noise_std / current_noise_std) * x - (previous_scale_factor * (exp_factor - 1.0)) * noise
|
||||
factor = exp(-(previous_ratio - current_ratio)) - 1.0
|
||||
denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise
|
||||
return denoised_x
|
||||
|
||||
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
|
||||
|
@ -76,13 +78,13 @@ class DPMSolver(Scheduler):
|
|||
estimation_delta = (current_data_estimation - next_data_estimation) / (
|
||||
(current_ratio - next_ratio) / (previous_ratio - current_ratio)
|
||||
)
|
||||
exp_neg_factor = exp(-(previous_ratio - current_ratio))
|
||||
x_t = (
|
||||
factor = exp(-(previous_ratio - current_ratio)) - 1.0
|
||||
denoised_x = (
|
||||
(previous_std / current_std) * x
|
||||
- (previous_scale_factor * (exp_neg_factor - 1.0)) * current_data_estimation
|
||||
- 0.5 * (previous_scale_factor * (exp_neg_factor - 1.0)) * estimation_delta
|
||||
- (factor * previous_scale_factor) * current_data_estimation
|
||||
- 0.5 * (factor * previous_scale_factor) * estimation_delta
|
||||
)
|
||||
return x_t
|
||||
return denoised_x
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
|
Loading…
Reference in a new issue