mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 23:28:45 +00:00
add LCMSolver (Latent Consistency Models)
This commit is contained in:
parent
4a619e84f0
commit
c8c6294550
|
@ -2,6 +2,7 @@ from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM
|
||||||
from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM
|
from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM
|
||||||
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
||||||
from refiners.foundationals.latent_diffusion.solvers.euler import Euler
|
from refiners.foundationals.latent_diffusion.solvers.euler import Euler
|
||||||
|
from refiners.foundationals.latent_diffusion.solvers.lcm import LCMSolver
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
|
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
|
||||||
|
|
||||||
__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler", "NoiseSchedule"]
|
__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler", "LCMSolver", "NoiseSchedule"]
|
||||||
|
|
122
src/refiners/foundationals/latent_diffusion/solvers/lcm.py
Normal file
122
src/refiners/foundationals/latent_diffusion/solvers/lcm.py
Normal file
|
@ -0,0 +1,122 @@
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
||||||
|
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
|
||||||
|
|
||||||
|
|
||||||
|
class LCMSolver(Solver):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_inference_steps: int,
|
||||||
|
num_train_timesteps: int = 1_000,
|
||||||
|
num_orig_steps: int = 50,
|
||||||
|
initial_diffusion_rate: float = 8.5e-4,
|
||||||
|
final_diffusion_rate: float = 1.2e-2,
|
||||||
|
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||||
|
diffusers_mode: bool = False,
|
||||||
|
device: torch.device | str = "cpu",
|
||||||
|
dtype: torch.dtype = torch.float32,
|
||||||
|
):
|
||||||
|
assert (
|
||||||
|
num_orig_steps >= num_inference_steps
|
||||||
|
), f"num_orig_steps ({num_orig_steps}) < num_inference_steps ({num_inference_steps})"
|
||||||
|
|
||||||
|
self._dpm = [
|
||||||
|
DPMSolver(
|
||||||
|
num_inference_steps=num_orig_steps,
|
||||||
|
num_train_timesteps=num_train_timesteps,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
if diffusers_mode:
|
||||||
|
# Diffusers recomputes the timesteps in LCMScheduler,
|
||||||
|
# and it does it slightly differently than DPM Solver.
|
||||||
|
# We provide this option to reproduce Diffusers' output.
|
||||||
|
k = num_train_timesteps // num_orig_steps
|
||||||
|
ts = np.asarray(list(range(1, num_orig_steps + 1))) * k - 1
|
||||||
|
self.dpm.timesteps = torch.tensor(ts, device=device).flip(0)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
num_inference_steps=num_inference_steps,
|
||||||
|
num_train_timesteps=num_train_timesteps,
|
||||||
|
initial_diffusion_rate=initial_diffusion_rate,
|
||||||
|
final_diffusion_rate=final_diffusion_rate,
|
||||||
|
noise_schedule=noise_schedule,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dpm(self):
|
||||||
|
return self._dpm[0]
|
||||||
|
|
||||||
|
def _generate_timesteps(self) -> torch.Tensor:
|
||||||
|
# Note: not the same as torch.linspace(start=0, end=num_train_timesteps, steps=5)[1:],
|
||||||
|
# e.g. for 4 steps we use [999, 759, 500, 260] instead of [999, 749, 499, 249].
|
||||||
|
# This is due to the use of the Skipping-Steps technique during distillation,
|
||||||
|
# see section 4.3 of the Latent Consistency Models paper (Luo 2023).
|
||||||
|
# `k` in the paper is `num_train_timesteps / num_orig_steps`. In practice, SDXL
|
||||||
|
# LCMs are distilled with DPM++.
|
||||||
|
|
||||||
|
self.timestep_indices: list[int] = (
|
||||||
|
torch.floor(
|
||||||
|
torch.linspace(
|
||||||
|
start=0,
|
||||||
|
end=self.dpm.num_inference_steps,
|
||||||
|
steps=self.num_inference_steps + 1,
|
||||||
|
)[:-1]
|
||||||
|
)
|
||||||
|
.int()
|
||||||
|
.tolist() # type: ignore
|
||||||
|
)
|
||||||
|
return self.dpm.timesteps[self.timestep_indices]
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
predicted_noise: torch.Tensor,
|
||||||
|
step: int,
|
||||||
|
generator: torch.Generator | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
current_timestep = self.timesteps[step]
|
||||||
|
scale_factor = self.cumulative_scale_factors[current_timestep]
|
||||||
|
noise_ratio = self.noise_std[current_timestep]
|
||||||
|
estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor
|
||||||
|
|
||||||
|
# To understand the values of c_skip and c_out,
|
||||||
|
# see "Parameterization for Consistency Models" in appendix C
|
||||||
|
# of the Consistency Models paper (Song 2023) and Karras 2022.
|
||||||
|
#
|
||||||
|
# However, note that there are two major differences:
|
||||||
|
# - epsilon is unused (= 0);
|
||||||
|
# - c_out is missing a `sigma` factor.
|
||||||
|
#
|
||||||
|
# This equation is the one used in the original implementation
|
||||||
|
# (https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7)
|
||||||
|
# and hence the one used to train all available models.
|
||||||
|
#
|
||||||
|
# See https://github.com/luosiallen/latent-consistency-model/issues/82
|
||||||
|
# for more discussion regarding this.
|
||||||
|
|
||||||
|
sigma = 0.5 # assume standard deviation of data distribution is 0.5
|
||||||
|
t = current_timestep * 10 # make curve sharper
|
||||||
|
c_skip = sigma**2 / (t**2 + sigma**2)
|
||||||
|
c_out = t / torch.sqrt(sigma**2 + t**2)
|
||||||
|
|
||||||
|
denoised_x = c_skip * x + c_out * estimated_denoised_data
|
||||||
|
|
||||||
|
if step == self.num_inference_steps - 1:
|
||||||
|
return denoised_x
|
||||||
|
|
||||||
|
# re-noise intermediate steps
|
||||||
|
noise = torch.randn(
|
||||||
|
predicted_noise.shape,
|
||||||
|
generator=generator,
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.dtype,
|
||||||
|
)
|
||||||
|
next_step = int(self.timestep_indices[step + 1])
|
||||||
|
return self.dpm.add_noise(x=denoised_x, noise=noise, step=next_step)
|
|
@ -2,10 +2,10 @@ from typing import cast
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from torch import Tensor, allclose, device as Device, equal, isclose, randn
|
from torch import Generator, Tensor, allclose, device as Device, equal, isclose, randn
|
||||||
|
|
||||||
from refiners.fluxion import manual_seed
|
from refiners.fluxion import manual_seed
|
||||||
from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler, NoiseSchedule
|
from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler, LCMSolver, NoiseSchedule
|
||||||
|
|
||||||
|
|
||||||
def test_ddpm_diffusers():
|
def test_ddpm_diffusers():
|
||||||
|
@ -100,6 +100,49 @@ def test_euler_diffusers():
|
||||||
assert allclose(diffusers_output, refiners_output, rtol=0.02), f"outputs differ at step {step}"
|
assert allclose(diffusers_output, refiners_output, rtol=0.02), f"outputs differ at step {step}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_lcm_diffusers():
|
||||||
|
from diffusers import LCMScheduler # type: ignore
|
||||||
|
|
||||||
|
manual_seed(0)
|
||||||
|
|
||||||
|
# LCMScheduler is stochastic, make sure we use identical generators
|
||||||
|
diffusers_generator = Generator().manual_seed(42)
|
||||||
|
refiners_generator = Generator().manual_seed(42)
|
||||||
|
|
||||||
|
diffusers_scheduler = LCMScheduler()
|
||||||
|
diffusers_scheduler.set_timesteps(4)
|
||||||
|
refiners_scheduler = LCMSolver(num_inference_steps=4, diffusers_mode=True)
|
||||||
|
|
||||||
|
# diffusers_mode means the timesteps are the same
|
||||||
|
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
|
||||||
|
|
||||||
|
sample = randn(1, 4, 32, 32)
|
||||||
|
predicted_noise = randn(1, 4, 32, 32)
|
||||||
|
|
||||||
|
for step, timestep in enumerate(diffusers_scheduler.timesteps):
|
||||||
|
alpha_prod_t = diffusers_scheduler.alphas_cumprod[timestep]
|
||||||
|
diffusers_noise_ratio = (1 - alpha_prod_t).sqrt()
|
||||||
|
diffusers_scale_factor = alpha_prod_t.sqrt()
|
||||||
|
|
||||||
|
refiners_scale_factor = refiners_scheduler.cumulative_scale_factors[timestep]
|
||||||
|
refiners_noise_ratio = refiners_scheduler.noise_std[timestep]
|
||||||
|
|
||||||
|
assert refiners_scale_factor == diffusers_scale_factor
|
||||||
|
assert refiners_noise_ratio == diffusers_noise_ratio
|
||||||
|
|
||||||
|
d_out = diffusers_scheduler.step(predicted_noise, timestep, sample, generator=diffusers_generator) # type: ignore
|
||||||
|
diffusers_output = cast(Tensor, d_out.prev_sample) # type: ignore
|
||||||
|
|
||||||
|
refiners_output = refiners_scheduler(
|
||||||
|
x=sample,
|
||||||
|
predicted_noise=predicted_noise,
|
||||||
|
step=step,
|
||||||
|
generator=refiners_generator,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert allclose(refiners_output, diffusers_output, rtol=0.01), f"outputs differ at step {step}"
|
||||||
|
|
||||||
|
|
||||||
def test_scheduler_remove_noise():
|
def test_scheduler_remove_noise():
|
||||||
from diffusers import DDIMScheduler # type: ignore
|
from diffusers import DDIMScheduler # type: ignore
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue