From 8423c5efa75aed503844bd73cabfeb527447203c Mon Sep 17 00:00:00 2001 From: Israfel Salazar Date: Wed, 10 Jan 2024 11:32:40 +0100 Subject: [PATCH] feature: Euler scheduler (#138) --- .../foundationals/latent_diffusion/model.py | 3 +- .../latent_diffusion/schedulers/__init__.py | 8 +- .../latent_diffusion/schedulers/ddim.py | 15 ++-- .../latent_diffusion/schedulers/dpm_solver.py | 15 +--- .../latent_diffusion/schedulers/euler.py | 83 +++++++++++++++++++ .../latent_diffusion/schedulers/scheduler.py | 11 ++- .../latent_diffusion/test_schedulers.py | 36 +++++++- 7 files changed, 138 insertions(+), 33 deletions(-) create mode 100644 src/refiners/foundationals/latent_diffusion/schedulers/euler.py diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py index a283a0b..22618d7 100644 --- a/src/refiners/foundationals/latent_diffusion/model.py +++ b/src/refiners/foundationals/latent_diffusion/model.py @@ -11,7 +11,6 @@ from refiners.foundationals.latent_diffusion.schedulers.scheduler import Schedul T = TypeVar("T", bound="fl.Module") - TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel") @@ -91,6 +90,8 @@ class LatentDiffusionModel(fl.Module, ABC): self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs) latents = torch.cat(tensors=(x, x)) # for classifier-free guidance + # scale latents for schedulers that need it + latents = self.scheduler.scale_model_input(latents, step=step) unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2) # classifier-free guidance diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/__init__.py b/src/refiners/foundationals/latent_diffusion/schedulers/__init__.py index 5a9be28..23bd7f4 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/__init__.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/__init__.py @@ -2,10 +2,6 @@ from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM from refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler +from refiners.foundationals.latent_diffusion.schedulers.euler import EulerScheduler -__all__ = [ - "Scheduler", - "DPMSolver", - "DDPM", - "DDIM", -] +__all__ = ["Scheduler", "DPMSolver", "DDPM", "DDIM", "EulerScheduler"] diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py b/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py index 18067bd..afb6ff2 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py @@ -1,4 +1,4 @@ -from torch import Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor +from torch import Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor, Generator from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler @@ -34,7 +34,7 @@ class DDIM(Scheduler): timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1 return timesteps.flip(0) - def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor: + def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: timestep, previous_timestep = ( self.timesteps[step], ( @@ -43,13 +43,10 @@ class DDIM(Scheduler): else tensor(data=[0], device=self.device, dtype=self.dtype) ), ) - current_scale_factor, previous_scale_factor = ( - self.cumulative_scale_factors[timestep], - ( - self.cumulative_scale_factors[previous_timestep] - if previous_timestep > 0 - else self.cumulative_scale_factors[0] - ), + current_scale_factor, previous_scale_factor = self.cumulative_scale_factors[timestep], ( + self.cumulative_scale_factors[previous_timestep] + if previous_timestep > 0 + else self.cumulative_scale_factors[0] ) predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor denoised_x = previous_scale_factor * predicted_x + sqrt(1 - previous_scale_factor**2) * noise diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py b/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py index e1e11dc..3b79348 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py @@ -1,9 +1,7 @@ -from collections import deque - -import numpy as np -from torch import Tensor, device as Device, dtype as Dtype, exp, float32, tensor - from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler +import numpy as np +from torch import Tensor, device as Device, tensor, exp, float32, dtype as Dtype, Generator +from collections import deque class DPMSolver(Scheduler): @@ -90,12 +88,7 @@ class DPMSolver(Scheduler): ) return denoised_x - def __call__( - self, - x: Tensor, - noise: Tensor, - step: int, - ) -> Tensor: + def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: """ Represents one step of the backward diffusion process that iteratively denoises the input data `x`. diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/euler.py b/src/refiners/foundationals/latent_diffusion/schedulers/euler.py new file mode 100644 index 0000000..bb340d2 --- /dev/null +++ b/src/refiners/foundationals/latent_diffusion/schedulers/euler.py @@ -0,0 +1,83 @@ +from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler +from torch import Tensor, device as Device, dtype as Dtype, float32, tensor, Generator +import torch +import numpy as np + + +class EulerScheduler(Scheduler): + def __init__( + self, + num_inference_steps: int, + num_train_timesteps: int = 1_000, + initial_diffusion_rate: float = 8.5e-4, + final_diffusion_rate: float = 1.2e-2, + noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC, + device: Device | str = "cpu", + dtype: Dtype = float32, + ): + if noise_schedule != NoiseSchedule.QUADRATIC: + raise NotImplementedError + super().__init__( + num_inference_steps=num_inference_steps, + num_train_timesteps=num_train_timesteps, + initial_diffusion_rate=initial_diffusion_rate, + final_diffusion_rate=final_diffusion_rate, + noise_schedule=noise_schedule, + device=device, + dtype=dtype, + ) + self.sigmas = self._generate_sigmas() + + @property + def init_noise_sigma(self) -> Tensor: + return self.sigmas.max() + + def _generate_timesteps(self) -> Tensor: + # We need to use numpy here because: + # numpy.linspace(0,999,31)[15] is 499.49999999999994 + # torch.linspace(0,999,31)[15] is 499.5 + # ...and we want the same result as the original codebase. + timesteps = torch.tensor( + np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps), dtype=self.dtype, device=self.device + ).flip(0) + return timesteps + + def _generate_sigmas(self) -> Tensor: + sigmas = self.noise_std / self.cumulative_scale_factors + sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy())) + sigmas = torch.cat([sigmas, tensor([0.0])]) + return sigmas.to(device=self.device, dtype=self.dtype) + + def scale_model_input(self, x: Tensor, step: int) -> Tensor: + sigma = self.sigmas[step] + return x / ((sigma**2 + 1) ** 0.5) + + def __call__( + self, + x: Tensor, + noise: Tensor, + step: int, + generator: Generator | None = None, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + ) -> Tensor: + sigma = self.sigmas[step] + + gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0 + + alt_noise = torch.randn(noise.shape, generator=generator, device=noise.device, dtype=noise.dtype) + eps = alt_noise * s_noise + sigma_hat = sigma * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + predicted_x = x - sigma_hat * noise + + # 1st order Euler + derivative = (x - predicted_x) / sigma_hat + dt = self.sigmas[step + 1] - sigma_hat + denoised_x = x + derivative * dt + + return denoised_x diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py b/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py index abf106c..f570ad8 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py @@ -1,9 +1,8 @@ from abc import ABC, abstractmethod from enum import Enum +from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log, Generator from typing import TypeVar -from torch import Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt - T = TypeVar("T", bound="Scheduler") @@ -50,7 +49,7 @@ class Scheduler(ABC): self.timesteps = self._generate_timesteps() @abstractmethod - def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor: + def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: """ Applies a step of the diffusion process to the input tensor `x` using the provided `noise` and `timestep`. @@ -71,6 +70,12 @@ class Scheduler(ABC): def steps(self) -> list[int]: return list(range(self.num_inference_steps)) + def scale_model_input(self, x: Tensor, step: int) -> Tensor: + """ + For compatibility with schedulers that need to scale the input according to the current timestep. + """ + return x + def sample_power_distribution(self, power: float = 2, /) -> Tensor: return ( linspace( diff --git a/tests/foundationals/latent_diffusion/test_schedulers.py b/tests/foundationals/latent_diffusion/test_schedulers.py index fc94594..c8a1c98 100644 --- a/tests/foundationals/latent_diffusion/test_schedulers.py +++ b/tests/foundationals/latent_diffusion/test_schedulers.py @@ -2,10 +2,10 @@ from typing import cast from warnings import warn import pytest -from torch import Tensor, allclose, device as Device, equal, randn +from torch import Tensor, allclose, device as Device, equal, randn, isclose from refiners.fluxion import manual_seed -from refiners.foundationals.latent_diffusion.schedulers import DDIM, DDPM, DPMSolver +from refiners.foundationals.latent_diffusion.schedulers import DDIM, DDPM, DPMSolver, EulerScheduler def test_ddpm_diffusers(): @@ -63,6 +63,34 @@ def test_ddim_diffusers(): assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}" +def test_euler_diffusers(): + from diffusers import EulerDiscreteScheduler + + manual_seed(0) + diffusers_scheduler = EulerDiscreteScheduler( + beta_end=0.012, + beta_schedule="scaled_linear", + beta_start=0.00085, + num_train_timesteps=1000, + steps_offset=1, + timestep_spacing="linspace", + use_karras_sigmas=False, + ) + diffusers_scheduler.set_timesteps(30) + refiners_scheduler = EulerScheduler(num_inference_steps=30) + + sample = randn(1, 4, 32, 32) + noise = randn(1, 4, 32, 32) + + assert isclose(diffusers_scheduler.init_noise_sigma, refiners_scheduler.init_noise_sigma), "init_noise_sigma differ" + + for step, timestep in enumerate(diffusers_scheduler.timesteps): + diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).prev_sample) # type: ignore + refiners_output = refiners_scheduler(x=sample, noise=noise, step=step) + + assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}" + + def test_scheduler_remove_noise(): from diffusers import DDIMScheduler # type: ignore @@ -84,7 +112,9 @@ def test_scheduler_remove_noise(): noise = randn(1, 4, 32, 32) for step, timestep in enumerate(diffusers_scheduler.timesteps): - diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).pred_original_sample) # type: ignore + diffusers_output = cast( + Tensor, diffusers_scheduler.step(noise, timestep, sample).pred_original_sample + ) # type: ignore refiners_output = refiners_scheduler.remove_noise(x=sample, noise=noise, step=step) assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"