mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
add LCMSolver (Latent Consistency Models)
This commit is contained in:
parent
4a619e84f0
commit
c8c6294550
|
@ -2,6 +2,7 @@ from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM
|
|||
from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM
|
||||
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
||||
from refiners.foundationals.latent_diffusion.solvers.euler import Euler
|
||||
from refiners.foundationals.latent_diffusion.solvers.lcm import LCMSolver
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
|
||||
|
||||
__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler", "NoiseSchedule"]
|
||||
__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler", "LCMSolver", "NoiseSchedule"]
|
||||
|
|
122
src/refiners/foundationals/latent_diffusion/solvers/lcm.py
Normal file
122
src/refiners/foundationals/latent_diffusion/solvers/lcm.py
Normal file
|
@ -0,0 +1,122 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
|
||||
|
||||
|
||||
class LCMSolver(Solver):
|
||||
def __init__(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
num_train_timesteps: int = 1_000,
|
||||
num_orig_steps: int = 50,
|
||||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||
diffusers_mode: bool = False,
|
||||
device: torch.device | str = "cpu",
|
||||
dtype: torch.dtype = torch.float32,
|
||||
):
|
||||
assert (
|
||||
num_orig_steps >= num_inference_steps
|
||||
), f"num_orig_steps ({num_orig_steps}) < num_inference_steps ({num_inference_steps})"
|
||||
|
||||
self._dpm = [
|
||||
DPMSolver(
|
||||
num_inference_steps=num_orig_steps,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
]
|
||||
|
||||
if diffusers_mode:
|
||||
# Diffusers recomputes the timesteps in LCMScheduler,
|
||||
# and it does it slightly differently than DPM Solver.
|
||||
# We provide this option to reproduce Diffusers' output.
|
||||
k = num_train_timesteps // num_orig_steps
|
||||
ts = np.asarray(list(range(1, num_orig_steps + 1))) * k - 1
|
||||
self.dpm.timesteps = torch.tensor(ts, device=device).flip(0)
|
||||
|
||||
super().__init__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
initial_diffusion_rate=initial_diffusion_rate,
|
||||
final_diffusion_rate=final_diffusion_rate,
|
||||
noise_schedule=noise_schedule,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
@property
|
||||
def dpm(self):
|
||||
return self._dpm[0]
|
||||
|
||||
def _generate_timesteps(self) -> torch.Tensor:
|
||||
# Note: not the same as torch.linspace(start=0, end=num_train_timesteps, steps=5)[1:],
|
||||
# e.g. for 4 steps we use [999, 759, 500, 260] instead of [999, 749, 499, 249].
|
||||
# This is due to the use of the Skipping-Steps technique during distillation,
|
||||
# see section 4.3 of the Latent Consistency Models paper (Luo 2023).
|
||||
# `k` in the paper is `num_train_timesteps / num_orig_steps`. In practice, SDXL
|
||||
# LCMs are distilled with DPM++.
|
||||
|
||||
self.timestep_indices: list[int] = (
|
||||
torch.floor(
|
||||
torch.linspace(
|
||||
start=0,
|
||||
end=self.dpm.num_inference_steps,
|
||||
steps=self.num_inference_steps + 1,
|
||||
)[:-1]
|
||||
)
|
||||
.int()
|
||||
.tolist() # type: ignore
|
||||
)
|
||||
return self.dpm.timesteps[self.timestep_indices]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
predicted_noise: torch.Tensor,
|
||||
step: int,
|
||||
generator: torch.Generator | None = None,
|
||||
) -> torch.Tensor:
|
||||
current_timestep = self.timesteps[step]
|
||||
scale_factor = self.cumulative_scale_factors[current_timestep]
|
||||
noise_ratio = self.noise_std[current_timestep]
|
||||
estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor
|
||||
|
||||
# To understand the values of c_skip and c_out,
|
||||
# see "Parameterization for Consistency Models" in appendix C
|
||||
# of the Consistency Models paper (Song 2023) and Karras 2022.
|
||||
#
|
||||
# However, note that there are two major differences:
|
||||
# - epsilon is unused (= 0);
|
||||
# - c_out is missing a `sigma` factor.
|
||||
#
|
||||
# This equation is the one used in the original implementation
|
||||
# (https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7)
|
||||
# and hence the one used to train all available models.
|
||||
#
|
||||
# See https://github.com/luosiallen/latent-consistency-model/issues/82
|
||||
# for more discussion regarding this.
|
||||
|
||||
sigma = 0.5 # assume standard deviation of data distribution is 0.5
|
||||
t = current_timestep * 10 # make curve sharper
|
||||
c_skip = sigma**2 / (t**2 + sigma**2)
|
||||
c_out = t / torch.sqrt(sigma**2 + t**2)
|
||||
|
||||
denoised_x = c_skip * x + c_out * estimated_denoised_data
|
||||
|
||||
if step == self.num_inference_steps - 1:
|
||||
return denoised_x
|
||||
|
||||
# re-noise intermediate steps
|
||||
noise = torch.randn(
|
||||
predicted_noise.shape,
|
||||
generator=generator,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
next_step = int(self.timestep_indices[step + 1])
|
||||
return self.dpm.add_noise(x=denoised_x, noise=noise, step=next_step)
|
|
@ -2,10 +2,10 @@ from typing import cast
|
|||
from warnings import warn
|
||||
|
||||
import pytest
|
||||
from torch import Tensor, allclose, device as Device, equal, isclose, randn
|
||||
from torch import Generator, Tensor, allclose, device as Device, equal, isclose, randn
|
||||
|
||||
from refiners.fluxion import manual_seed
|
||||
from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler, NoiseSchedule
|
||||
from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler, LCMSolver, NoiseSchedule
|
||||
|
||||
|
||||
def test_ddpm_diffusers():
|
||||
|
@ -100,6 +100,49 @@ def test_euler_diffusers():
|
|||
assert allclose(diffusers_output, refiners_output, rtol=0.02), f"outputs differ at step {step}"
|
||||
|
||||
|
||||
def test_lcm_diffusers():
|
||||
from diffusers import LCMScheduler # type: ignore
|
||||
|
||||
manual_seed(0)
|
||||
|
||||
# LCMScheduler is stochastic, make sure we use identical generators
|
||||
diffusers_generator = Generator().manual_seed(42)
|
||||
refiners_generator = Generator().manual_seed(42)
|
||||
|
||||
diffusers_scheduler = LCMScheduler()
|
||||
diffusers_scheduler.set_timesteps(4)
|
||||
refiners_scheduler = LCMSolver(num_inference_steps=4, diffusers_mode=True)
|
||||
|
||||
# diffusers_mode means the timesteps are the same
|
||||
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
|
||||
|
||||
sample = randn(1, 4, 32, 32)
|
||||
predicted_noise = randn(1, 4, 32, 32)
|
||||
|
||||
for step, timestep in enumerate(diffusers_scheduler.timesteps):
|
||||
alpha_prod_t = diffusers_scheduler.alphas_cumprod[timestep]
|
||||
diffusers_noise_ratio = (1 - alpha_prod_t).sqrt()
|
||||
diffusers_scale_factor = alpha_prod_t.sqrt()
|
||||
|
||||
refiners_scale_factor = refiners_scheduler.cumulative_scale_factors[timestep]
|
||||
refiners_noise_ratio = refiners_scheduler.noise_std[timestep]
|
||||
|
||||
assert refiners_scale_factor == diffusers_scale_factor
|
||||
assert refiners_noise_ratio == diffusers_noise_ratio
|
||||
|
||||
d_out = diffusers_scheduler.step(predicted_noise, timestep, sample, generator=diffusers_generator) # type: ignore
|
||||
diffusers_output = cast(Tensor, d_out.prev_sample) # type: ignore
|
||||
|
||||
refiners_output = refiners_scheduler(
|
||||
x=sample,
|
||||
predicted_noise=predicted_noise,
|
||||
step=step,
|
||||
generator=refiners_generator,
|
||||
)
|
||||
|
||||
assert allclose(refiners_output, diffusers_output, rtol=0.01), f"outputs differ at step {step}"
|
||||
|
||||
|
||||
def test_scheduler_remove_noise():
|
||||
from diffusers import DDIMScheduler # type: ignore
|
||||
|
||||
|
|
Loading…
Reference in a new issue