From c8c6294550ecfe529c76e606ee07b1b7329f5b24 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Thu, 15 Feb 2024 18:41:00 +0100 Subject: [PATCH] add LCMSolver (Latent Consistency Models) --- .../latent_diffusion/solvers/__init__.py | 3 +- .../latent_diffusion/solvers/lcm.py | 122 ++++++++++++++++++ .../latent_diffusion/test_solvers.py | 47 ++++++- 3 files changed, 169 insertions(+), 3 deletions(-) create mode 100644 src/refiners/foundationals/latent_diffusion/solvers/lcm.py diff --git a/src/refiners/foundationals/latent_diffusion/solvers/__init__.py b/src/refiners/foundationals/latent_diffusion/solvers/__init__.py index 72da806..6a9aa07 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/__init__.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/__init__.py @@ -2,6 +2,7 @@ from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver from refiners.foundationals.latent_diffusion.solvers.euler import Euler +from refiners.foundationals.latent_diffusion.solvers.lcm import LCMSolver from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver -__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler", "NoiseSchedule"] +__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler", "LCMSolver", "NoiseSchedule"] diff --git a/src/refiners/foundationals/latent_diffusion/solvers/lcm.py b/src/refiners/foundationals/latent_diffusion/solvers/lcm.py new file mode 100644 index 0000000..d0189c0 --- /dev/null +++ b/src/refiners/foundationals/latent_diffusion/solvers/lcm.py @@ -0,0 +1,122 @@ +import numpy as np +import torch + +from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver +from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver + + +class LCMSolver(Solver): + def __init__( + self, + num_inference_steps: int, + num_train_timesteps: int = 1_000, + num_orig_steps: int = 50, + initial_diffusion_rate: float = 8.5e-4, + final_diffusion_rate: float = 1.2e-2, + noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC, + diffusers_mode: bool = False, + device: torch.device | str = "cpu", + dtype: torch.dtype = torch.float32, + ): + assert ( + num_orig_steps >= num_inference_steps + ), f"num_orig_steps ({num_orig_steps}) < num_inference_steps ({num_inference_steps})" + + self._dpm = [ + DPMSolver( + num_inference_steps=num_orig_steps, + num_train_timesteps=num_train_timesteps, + device=device, + dtype=dtype, + ) + ] + + if diffusers_mode: + # Diffusers recomputes the timesteps in LCMScheduler, + # and it does it slightly differently than DPM Solver. + # We provide this option to reproduce Diffusers' output. + k = num_train_timesteps // num_orig_steps + ts = np.asarray(list(range(1, num_orig_steps + 1))) * k - 1 + self.dpm.timesteps = torch.tensor(ts, device=device).flip(0) + + super().__init__( + num_inference_steps=num_inference_steps, + num_train_timesteps=num_train_timesteps, + initial_diffusion_rate=initial_diffusion_rate, + final_diffusion_rate=final_diffusion_rate, + noise_schedule=noise_schedule, + device=device, + dtype=dtype, + ) + + @property + def dpm(self): + return self._dpm[0] + + def _generate_timesteps(self) -> torch.Tensor: + # Note: not the same as torch.linspace(start=0, end=num_train_timesteps, steps=5)[1:], + # e.g. for 4 steps we use [999, 759, 500, 260] instead of [999, 749, 499, 249]. + # This is due to the use of the Skipping-Steps technique during distillation, + # see section 4.3 of the Latent Consistency Models paper (Luo 2023). + # `k` in the paper is `num_train_timesteps / num_orig_steps`. In practice, SDXL + # LCMs are distilled with DPM++. + + self.timestep_indices: list[int] = ( + torch.floor( + torch.linspace( + start=0, + end=self.dpm.num_inference_steps, + steps=self.num_inference_steps + 1, + )[:-1] + ) + .int() + .tolist() # type: ignore + ) + return self.dpm.timesteps[self.timestep_indices] + + def __call__( + self, + x: torch.Tensor, + predicted_noise: torch.Tensor, + step: int, + generator: torch.Generator | None = None, + ) -> torch.Tensor: + current_timestep = self.timesteps[step] + scale_factor = self.cumulative_scale_factors[current_timestep] + noise_ratio = self.noise_std[current_timestep] + estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor + + # To understand the values of c_skip and c_out, + # see "Parameterization for Consistency Models" in appendix C + # of the Consistency Models paper (Song 2023) and Karras 2022. + # + # However, note that there are two major differences: + # - epsilon is unused (= 0); + # - c_out is missing a `sigma` factor. + # + # This equation is the one used in the original implementation + # (https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7) + # and hence the one used to train all available models. + # + # See https://github.com/luosiallen/latent-consistency-model/issues/82 + # for more discussion regarding this. + + sigma = 0.5 # assume standard deviation of data distribution is 0.5 + t = current_timestep * 10 # make curve sharper + c_skip = sigma**2 / (t**2 + sigma**2) + c_out = t / torch.sqrt(sigma**2 + t**2) + + denoised_x = c_skip * x + c_out * estimated_denoised_data + + if step == self.num_inference_steps - 1: + return denoised_x + + # re-noise intermediate steps + noise = torch.randn( + predicted_noise.shape, + generator=generator, + device=self.device, + dtype=self.dtype, + ) + next_step = int(self.timestep_indices[step + 1]) + return self.dpm.add_noise(x=denoised_x, noise=noise, step=next_step) diff --git a/tests/foundationals/latent_diffusion/test_solvers.py b/tests/foundationals/latent_diffusion/test_solvers.py index cd671c1..97444c3 100644 --- a/tests/foundationals/latent_diffusion/test_solvers.py +++ b/tests/foundationals/latent_diffusion/test_solvers.py @@ -2,10 +2,10 @@ from typing import cast from warnings import warn import pytest -from torch import Tensor, allclose, device as Device, equal, isclose, randn +from torch import Generator, Tensor, allclose, device as Device, equal, isclose, randn from refiners.fluxion import manual_seed -from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler, NoiseSchedule +from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler, LCMSolver, NoiseSchedule def test_ddpm_diffusers(): @@ -100,6 +100,49 @@ def test_euler_diffusers(): assert allclose(diffusers_output, refiners_output, rtol=0.02), f"outputs differ at step {step}" +def test_lcm_diffusers(): + from diffusers import LCMScheduler # type: ignore + + manual_seed(0) + + # LCMScheduler is stochastic, make sure we use identical generators + diffusers_generator = Generator().manual_seed(42) + refiners_generator = Generator().manual_seed(42) + + diffusers_scheduler = LCMScheduler() + diffusers_scheduler.set_timesteps(4) + refiners_scheduler = LCMSolver(num_inference_steps=4, diffusers_mode=True) + + # diffusers_mode means the timesteps are the same + assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps) + + sample = randn(1, 4, 32, 32) + predicted_noise = randn(1, 4, 32, 32) + + for step, timestep in enumerate(diffusers_scheduler.timesteps): + alpha_prod_t = diffusers_scheduler.alphas_cumprod[timestep] + diffusers_noise_ratio = (1 - alpha_prod_t).sqrt() + diffusers_scale_factor = alpha_prod_t.sqrt() + + refiners_scale_factor = refiners_scheduler.cumulative_scale_factors[timestep] + refiners_noise_ratio = refiners_scheduler.noise_std[timestep] + + assert refiners_scale_factor == diffusers_scale_factor + assert refiners_noise_ratio == diffusers_noise_ratio + + d_out = diffusers_scheduler.step(predicted_noise, timestep, sample, generator=diffusers_generator) # type: ignore + diffusers_output = cast(Tensor, d_out.prev_sample) # type: ignore + + refiners_output = refiners_scheduler( + x=sample, + predicted_noise=predicted_noise, + step=step, + generator=refiners_generator, + ) + + assert allclose(refiners_output, diffusers_output, rtol=0.01), f"outputs differ at step {step}" + + def test_scheduler_remove_noise(): from diffusers import DDIMScheduler # type: ignore