From aaddead17d9f9c32defdec28705890026e92388b Mon Sep 17 00:00:00 2001
From: Pierre Chapuis
Date: Thu, 18 Jan 2024 11:47:04 +0100
Subject: [PATCH] DPM: add a mode to use first order for last step

---
 .../latent_diffusion/schedulers/dpm_solver.py | 24 ++++++++++---------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py b/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py
index 52e706c..e570898 100644
--- a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py
+++ b/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py
@@ -7,9 +7,13 @@ from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSc
 
 
 class DPMSolver(Scheduler):
-    """Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
+    """
+    Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
 
-    We only support noise prediction for now.
+    Regarding last_step_first_order: DPM-Solver++ is known to introduce artifacts
+    when used with SDXL and few steps. This parameter is a way to mitigate that
+    effect by using a first-order (Euler) update instead of a second-order update
+    for the last step of the diffusion.
     """
 
     def __init__(
@@ -18,6 +22,7 @@ class DPMSolver(Scheduler):
         num_train_timesteps: int = 1_000,
         initial_diffusion_rate: float = 8.5e-4,
         final_diffusion_rate: float = 1.2e-2,
+        last_step_first_order: bool = False,
         noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
         device: Device | str = "cpu",
         dtype: Dtype = float32,
@@ -32,7 +37,7 @@ class DPMSolver(Scheduler):
             dtype=dtype,
         )
         self.estimated_data = deque([tensor([])] * 2, maxlen=2)
-        self.initial_steps = 0
+        self.last_step_first_order = last_step_first_order
 
     def _generate_timesteps(self) -> Tensor:
         # We need to use numpy here because:
@@ -102,11 +107,8 @@ class DPMSolver(Scheduler):
         scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
         estimated_denoised_data = (x - noise_ratio * noise) / scale_factor
         self.estimated_data.append(estimated_denoised_data)
-        denoised_x = (
-            self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
-            if (self.initial_steps == 0)
-            else self.multistep_dpm_solver_second_order_update(x=x, step=step)
-        )
-        if self.initial_steps < 2:
-            self.initial_steps += 1
-        return denoised_x
+
+        if step == 0 or (self.last_step_first_order and step == self.num_inference_steps - 1):
+            return self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
+
+        return self.multistep_dpm_solver_second_order_update(x=x, step=step)