add LCMSolver (Latent Consistency Models)

This commit is contained in:
Pierre Chapuis 2024-02-15 18:41:00 +01:00
parent 4a619e84f0
commit c8c6294550
3 changed files with 169 additions and 3 deletions

View file

@@ -2,6 +2,7 @@ from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM
from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
from refiners.foundationals.latent_diffusion.solvers.euler import Euler
from refiners.foundationals.latent_diffusion.solvers.lcm import LCMSolver
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler", "NoiseSchedule"]
__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler", "LCMSolver", "NoiseSchedule"]

View file

@@ -0,0 +1,122 @@
import numpy as np
import torch
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
class LCMSolver(Solver):
    """Latent Consistency Model solver.

    Samples with very few steps by skipping through the timestep schedule
    of an underlying DPM solver (the Skipping-Steps technique, see section
    4.3 of the Latent Consistency Models paper, Luo 2023). Intermediate
    steps are re-noised via the wrapped DPM solver.
    """

    def __init__(
        self,
        num_inference_steps: int,
        num_train_timesteps: int = 1_000,
        num_orig_steps: int = 50,
        initial_diffusion_rate: float = 8.5e-4,
        final_diffusion_rate: float = 1.2e-2,
        noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
        diffusers_mode: bool = False,
        device: torch.device | str = "cpu",
        dtype: torch.dtype = torch.float32,
    ):
        """Initialize the LCM solver.

        Args:
            num_inference_steps: number of denoising steps to actually run.
            num_train_timesteps: number of timesteps used at training time.
            num_orig_steps: number of steps of the underlying DPM solver
                whose schedule is skipped through; must be at least
                `num_inference_steps`.
            initial_diffusion_rate: initial value of the diffusion rate
                (forwarded to the `Solver` base class).
            final_diffusion_rate: final value of the diffusion rate
                (forwarded to the `Solver` base class).
            noise_schedule: shape of the noise schedule.
            diffusers_mode: if True, override the DPM timesteps so they
                match Diffusers' `LCMScheduler` exactly (see below).
            device: device to allocate solver tensors on.
            dtype: dtype of the solver tensors.
        """
        assert (
            num_orig_steps >= num_inference_steps
        ), f"num_orig_steps ({num_orig_steps}) < num_inference_steps ({num_inference_steps})"

        # NOTE(review): the DPM solver is stored inside a one-element list,
        # presumably to keep it out of attribute/submodule registration
        # performed by `Solver` — confirm against the Solver base class.
        self._dpm = [
            DPMSolver(
                num_inference_steps=num_orig_steps,
                num_train_timesteps=num_train_timesteps,
                device=device,
                dtype=dtype,
            )
        ]

        if diffusers_mode:
            # Diffusers recomputes the timesteps in LCMScheduler,
            # and it does it slightly differently than DPM Solver.
            # We provide this option to reproduce Diffusers' output.
            k = num_train_timesteps // num_orig_steps
            ts = np.asarray(list(range(1, num_orig_steps + 1))) * k - 1
            self.dpm.timesteps = torch.tensor(ts, device=device).flip(0)

        super().__init__(
            num_inference_steps=num_inference_steps,
            num_train_timesteps=num_train_timesteps,
            initial_diffusion_rate=initial_diffusion_rate,
            final_diffusion_rate=final_diffusion_rate,
            noise_schedule=noise_schedule,
            device=device,
            dtype=dtype,
        )

    @property
    def dpm(self) -> DPMSolver:
        # Accessor for the wrapped DPM solver (see the list trick in __init__).
        return self._dpm[0]

    def _generate_timesteps(self) -> torch.Tensor:
        """Pick `num_inference_steps` timesteps out of the DPM schedule.

        Also records the chosen indices in `self.timestep_indices`, which
        `__call__` uses to re-noise intermediate steps.
        """
        # Note: not the same as torch.linspace(start=0, end=num_train_timesteps, steps=5)[1:],
        # e.g. for 4 steps we use [999, 759, 500, 260] instead of [999, 749, 499, 249].
        # This is due to the use of the Skipping-Steps technique during distillation,
        # see section 4.3 of the Latent Consistency Models paper (Luo 2023).
        # `k` in the paper is `num_train_timesteps / num_orig_steps`. In practice, SDXL
        # LCMs are distilled with DPM++.
        self.timestep_indices: list[int] = (
            torch.floor(
                torch.linspace(
                    start=0,
                    end=self.dpm.num_inference_steps,
                    steps=self.num_inference_steps + 1,
                )[:-1]
            )
            .int()
            .tolist()  # type: ignore
        )
        return self.dpm.timesteps[self.timestep_indices]

    def __call__(
        self,
        x: torch.Tensor,
        predicted_noise: torch.Tensor,
        step: int,
        generator: torch.Generator | None = None,
    ) -> torch.Tensor:
        """Apply one LCM denoising step.

        Args:
            x: the current (noisy) latents.
            predicted_noise: the noise predicted by the model for `x`.
            step: index into `self.timesteps` of the current step.
            generator: optional RNG used when re-noising intermediate steps.

        Returns:
            The consistency-model estimate of the denoised sample; for all
            but the last step it is re-noised to the next timestep via the
            wrapped DPM solver.
        """
        current_timestep = self.timesteps[step]
        scale_factor = self.cumulative_scale_factors[current_timestep]
        noise_ratio = self.noise_std[current_timestep]
        # Standard epsilon-parameterization estimate of x_0.
        estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor

        # To understand the values of c_skip and c_out,
        # see "Parameterization for Consistency Models" in appendix C
        # of the Consistency Models paper (Song 2023) and Karras 2022.
        #
        # However, note that there are two major differences:
        # - epsilon is unused (= 0);
        # - c_out is missing a `sigma` factor.
        #
        # This equation is the one used in the original implementation
        # (https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7)
        # and hence the one used to train all available models.
        #
        # See https://github.com/luosiallen/latent-consistency-model/issues/82
        # for more discussion regarding this.
        sigma = 0.5  # assume standard deviation of data distribution is 0.5
        t = current_timestep * 10  # make curve sharper
        c_skip = sigma**2 / (t**2 + sigma**2)
        c_out = t / torch.sqrt(sigma**2 + t**2)

        denoised_x = c_skip * x + c_out * estimated_denoised_data

        if step == self.num_inference_steps - 1:
            return denoised_x

        # re-noise intermediate steps
        noise = torch.randn(
            predicted_noise.shape,
            generator=generator,
            device=self.device,
            dtype=self.dtype,
        )
        next_step = int(self.timestep_indices[step + 1])
        return self.dpm.add_noise(x=denoised_x, noise=noise, step=next_step)

View file

@@ -2,10 +2,10 @@ from typing import cast
from warnings import warn
import pytest
from torch import Tensor, allclose, device as Device, equal, isclose, randn
from torch import Generator, Tensor, allclose, device as Device, equal, isclose, randn
from refiners.fluxion import manual_seed
from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler, NoiseSchedule
from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler, LCMSolver, NoiseSchedule
def test_ddpm_diffusers():
@@ -100,6 +100,49 @@ def test_euler_diffusers():
assert allclose(diffusers_output, refiners_output, rtol=0.02), f"outputs differ at step {step}"
def test_lcm_diffusers():
    """Check LCMSolver against Diffusers' LCMScheduler, step by step."""
    from diffusers import LCMScheduler  # type: ignore

    manual_seed(0)

    # LCMScheduler draws noise internally, so both sides must consume
    # generators seeded identically for their outputs to be comparable.
    gen_diffusers = Generator().manual_seed(42)
    gen_refiners = Generator().manual_seed(42)

    scheduler = LCMScheduler()
    scheduler.set_timesteps(4)
    solver = LCMSolver(num_inference_steps=4, diffusers_mode=True)

    # With diffusers_mode enabled, both schedules must match exactly.
    assert equal(solver.timesteps, scheduler.timesteps)

    sample = randn(1, 4, 32, 32)
    predicted_noise = randn(1, 4, 32, 32)

    for step, timestep in enumerate(scheduler.timesteps):
        # Compare the per-timestep coefficients before comparing outputs.
        alpha_bar = scheduler.alphas_cumprod[timestep]
        assert solver.cumulative_scale_factors[timestep] == alpha_bar.sqrt()
        assert solver.noise_std[timestep] == (1 - alpha_bar).sqrt()

        d_out = scheduler.step(predicted_noise, timestep, sample, generator=gen_diffusers)  # type: ignore
        diffusers_output = cast(Tensor, d_out.prev_sample)  # type: ignore

        refiners_output = solver(
            x=sample,
            predicted_noise=predicted_noise,
            step=step,
            generator=gen_refiners,
        )

        assert allclose(refiners_output, diffusers_output, rtol=0.01), f"outputs differ at step {step}"
def test_scheduler_remove_noise():
from diffusers import DDIMScheduler # type: ignore