mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-14 00:58:13 +00:00
improve consistency of the dpm scheduler
This commit is contained in:
parent
7a62049d54
commit
585c7ad55a
|
@ -1,6 +1,6 @@
|
||||||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch import Tensor, device as Device, tensor, exp
|
from torch import Tensor, device as Device, tensor, exp, float32, dtype as Dtype
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@ class DPMSolver(Scheduler):
|
||||||
initial_diffusion_rate: float = 8.5e-4,
|
initial_diffusion_rate: float = 8.5e-4,
|
||||||
final_diffusion_rate: float = 1.2e-2,
|
final_diffusion_rate: float = 1.2e-2,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
|
dtype: Dtype = float32,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
|
@ -24,6 +25,7 @@ class DPMSolver(Scheduler):
|
||||||
initial_diffusion_rate=initial_diffusion_rate,
|
initial_diffusion_rate=initial_diffusion_rate,
|
||||||
final_diffusion_rate=final_diffusion_rate,
|
final_diffusion_rate=final_diffusion_rate,
|
||||||
device=device,
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
self.estimated_data = deque([tensor([])] * 2, maxlen=2)
|
self.estimated_data = deque([tensor([])] * 2, maxlen=2)
|
||||||
self.initial_steps = 0
|
self.initial_steps = 0
|
||||||
|
@ -52,8 +54,8 @@ class DPMSolver(Scheduler):
|
||||||
self.noise_std[previous_timestep],
|
self.noise_std[previous_timestep],
|
||||||
self.noise_std[timestep],
|
self.noise_std[timestep],
|
||||||
)
|
)
|
||||||
exp_factor = exp(-(previous_ratio - current_ratio))
|
factor = exp(-(previous_ratio - current_ratio)) - 1.0
|
||||||
denoised_x = (previous_noise_std / current_noise_std) * x - (previous_scale_factor * (exp_factor - 1.0)) * noise
|
denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise
|
||||||
return denoised_x
|
return denoised_x
|
||||||
|
|
||||||
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
|
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
|
||||||
|
@ -76,13 +78,13 @@ class DPMSolver(Scheduler):
|
||||||
estimation_delta = (current_data_estimation - next_data_estimation) / (
|
estimation_delta = (current_data_estimation - next_data_estimation) / (
|
||||||
(current_ratio - next_ratio) / (previous_ratio - current_ratio)
|
(current_ratio - next_ratio) / (previous_ratio - current_ratio)
|
||||||
)
|
)
|
||||||
exp_neg_factor = exp(-(previous_ratio - current_ratio))
|
factor = exp(-(previous_ratio - current_ratio)) - 1.0
|
||||||
x_t = (
|
denoised_x = (
|
||||||
(previous_std / current_std) * x
|
(previous_std / current_std) * x
|
||||||
- (previous_scale_factor * (exp_neg_factor - 1.0)) * current_data_estimation
|
- (factor * previous_scale_factor) * current_data_estimation
|
||||||
- 0.5 * (previous_scale_factor * (exp_neg_factor - 1.0)) * estimation_delta
|
- 0.5 * (factor * previous_scale_factor) * estimation_delta
|
||||||
)
|
)
|
||||||
return x_t
|
return denoised_x
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Reference in a new issue