DPM: add a mode to use a first-order update for the last step

This commit is contained in:
Pierre Chapuis 2024-01-18 11:47:04 +01:00
parent 17d9701dde
commit aaddead17d

View file

@ -7,9 +7,13 @@ from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSc
class DPMSolver(Scheduler):
"""Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
"""
Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
We only support noise prediction for now.
Regarding last_step_first_order: DPM-Solver++ is known to introduce artifacts
when used with SDXL and few steps. This parameter is a way to mitigate that
effect by using a first-order (Euler) update instead of a second-order update
for the last step of the diffusion.
"""
def __init__(
@ -18,6 +22,7 @@ class DPMSolver(Scheduler):
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
last_step_first_order: bool = False,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
device: Device | str = "cpu",
dtype: Dtype = float32,
@ -32,7 +37,7 @@ class DPMSolver(Scheduler):
dtype=dtype,
)
self.estimated_data = deque([tensor([])] * 2, maxlen=2)
self.initial_steps = 0
self.last_step_first_order = last_step_first_order
def _generate_timesteps(self) -> Tensor:
# We need to use numpy here because:
@ -102,11 +107,8 @@ class DPMSolver(Scheduler):
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
estimated_denoised_data = (x - noise_ratio * noise) / scale_factor
self.estimated_data.append(estimated_denoised_data)
denoised_x = (
self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
if (self.initial_steps == 0)
else self.multistep_dpm_solver_second_order_update(x=x, step=step)
)
if self.initial_steps < 2:
self.initial_steps += 1
return denoised_x
if step == 0 or (self.last_step_first_order and step == self.num_inference_steps - 1):
return self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
return self.multistep_dpm_solver_second_order_update(x=x, step=step)