mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
DPM: add a mode to use first order for last step
This commit is contained in:
parent
17d9701dde
commit
aaddead17d
|
@ -7,9 +7,13 @@ from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSc
|
||||||
|
|
||||||
|
|
||||||
class DPMSolver(Scheduler):
|
class DPMSolver(Scheduler):
|
||||||
"""Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
|
"""
|
||||||
|
Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
|
||||||
|
|
||||||
We only support noise prediction for now.
|
Regarding last_step_first_order: DPM-Solver++ is known to introduce artifacts
|
||||||
|
when used with SDXL and few steps. This parameter is a way to mitigate that
|
||||||
|
effect by using a first-order (Euler) update instead of a second-order update
|
||||||
|
for the last step of the diffusion.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -18,6 +22,7 @@ class DPMSolver(Scheduler):
|
||||||
num_train_timesteps: int = 1_000,
|
num_train_timesteps: int = 1_000,
|
||||||
initial_diffusion_rate: float = 8.5e-4,
|
initial_diffusion_rate: float = 8.5e-4,
|
||||||
final_diffusion_rate: float = 1.2e-2,
|
final_diffusion_rate: float = 1.2e-2,
|
||||||
|
last_step_first_order: bool = False,
|
||||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: Dtype = float32,
|
dtype: Dtype = float32,
|
||||||
|
@ -32,7 +37,7 @@ class DPMSolver(Scheduler):
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
self.estimated_data = deque([tensor([])] * 2, maxlen=2)
|
self.estimated_data = deque([tensor([])] * 2, maxlen=2)
|
||||||
self.initial_steps = 0
|
self.last_step_first_order = last_step_first_order
|
||||||
|
|
||||||
def _generate_timesteps(self) -> Tensor:
|
def _generate_timesteps(self) -> Tensor:
|
||||||
# We need to use numpy here because:
|
# We need to use numpy here because:
|
||||||
|
@ -102,11 +107,8 @@ class DPMSolver(Scheduler):
|
||||||
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
|
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
|
||||||
estimated_denoised_data = (x - noise_ratio * noise) / scale_factor
|
estimated_denoised_data = (x - noise_ratio * noise) / scale_factor
|
||||||
self.estimated_data.append(estimated_denoised_data)
|
self.estimated_data.append(estimated_denoised_data)
|
||||||
denoised_x = (
|
|
||||||
self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
|
if step == 0 or (self.last_step_first_order and step == self.num_inference_steps - 1):
|
||||||
if (self.initial_steps == 0)
|
return self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
|
||||||
else self.multistep_dpm_solver_second_order_update(x=x, step=step)
|
|
||||||
)
|
return self.multistep_dpm_solver_second_order_update(x=x, step=step)
|
||||||
if self.initial_steps < 2:
|
|
||||||
self.initial_steps += 1
|
|
||||||
return denoised_x
|
|
||||||
|
|
Loading…
Reference in a new issue