mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
DPM: add a mode to use first order for last step
This commit is contained in:
parent
17d9701dde
commit
aaddead17d
|
@ -7,9 +7,13 @@ from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSc
|
|||
|
||||
|
||||
class DPMSolver(Scheduler):
|
||||
"""Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
|
||||
"""
|
||||
Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
|
||||
|
||||
We only support noise prediction for now.
|
||||
Regarding last_step_first_order: DPM-Solver++ is known to introduce artifacts
|
||||
when used with SDXL and few steps. This parameter is a way to mitigate that
|
||||
effect by using a first-order (Euler) update instead of a second-order update
|
||||
for the last step of the diffusion.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -18,6 +22,7 @@ class DPMSolver(Scheduler):
|
|||
num_train_timesteps: int = 1_000,
|
||||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
last_step_first_order: bool = False,
|
||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||
device: Device | str = "cpu",
|
||||
dtype: Dtype = float32,
|
||||
|
@ -32,7 +37,7 @@ class DPMSolver(Scheduler):
|
|||
dtype=dtype,
|
||||
)
|
||||
self.estimated_data = deque([tensor([])] * 2, maxlen=2)
|
||||
self.initial_steps = 0
|
||||
self.last_step_first_order = last_step_first_order
|
||||
|
||||
def _generate_timesteps(self) -> Tensor:
|
||||
# We need to use numpy here because:
|
||||
|
@ -102,11 +107,8 @@ class DPMSolver(Scheduler):
|
|||
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
|
||||
estimated_denoised_data = (x - noise_ratio * noise) / scale_factor
|
||||
self.estimated_data.append(estimated_denoised_data)
|
||||
denoised_x = (
|
||||
self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
|
||||
if (self.initial_steps == 0)
|
||||
else self.multistep_dpm_solver_second_order_update(x=x, step=step)
|
||||
)
|
||||
if self.initial_steps < 2:
|
||||
self.initial_steps += 1
|
||||
return denoised_x
|
||||
|
||||
if step == 0 or (self.last_step_first_order and step == self.num_inference_steps - 1):
|
||||
return self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
|
||||
|
||||
return self.multistep_dpm_solver_second_order_update(x=x, step=step)
|
||||
|
|
Loading…
Reference in a new issue