From 9e8c2a3753e54bd996a0f64b63d8b5e73215a5fb Mon Sep 17 00:00:00 2001
From: Pierre Chapuis
Date: Wed, 10 Jul 2024 16:08:05 +0200
Subject: [PATCH] add FrankenSolver

This solver is designed to use Diffusers Schedulers as Refiners Solvers.
---
 .../latent_diffusion/solvers/__init__.py |  2 +
 .../latent_diffusion/solvers/franken.py  | 57 +++++++++++++++++++
 .../latent_diffusion/test_solvers.py     | 37 ++++++++++++
 3 files changed, 96 insertions(+)
 create mode 100644 src/refiners/foundationals/latent_diffusion/solvers/franken.py

diff --git a/src/refiners/foundationals/latent_diffusion/solvers/__init__.py b/src/refiners/foundationals/latent_diffusion/solvers/__init__.py
index 7f904a5..737bab5 100644
--- a/src/refiners/foundationals/latent_diffusion/solvers/__init__.py
+++ b/src/refiners/foundationals/latent_diffusion/solvers/__init__.py
@@ -2,6 +2,7 @@ from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM
 from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM
 from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
 from refiners.foundationals.latent_diffusion.solvers.euler import Euler
+from refiners.foundationals.latent_diffusion.solvers.franken import FrankenSolver
 from refiners.foundationals.latent_diffusion.solvers.lcm import LCMSolver
 from refiners.foundationals.latent_diffusion.solvers.solver import (
     ModelPredictionType,
@@ -18,6 +19,7 @@ __all__ = [
     "DDPM",
     "DDIM",
     "Euler",
+    "FrankenSolver",
     "LCMSolver",
     "ModelPredictionType",
     "NoiseSchedule",
diff --git a/src/refiners/foundationals/latent_diffusion/solvers/franken.py b/src/refiners/foundationals/latent_diffusion/solvers/franken.py
new file mode 100644
index 0000000..2842e67
--- /dev/null
+++ b/src/refiners/foundationals/latent_diffusion/solvers/franken.py
@@ -0,0 +1,57 @@
+import dataclasses
+from typing import Any, cast
+
+from torch import Generator, Tensor
+
+from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing
+
+
+class FrankenSolver(Solver):
+    """Lets you use Diffusers Schedulers as Refiners Solvers.
+
+    For instance:
+        from diffusers import EulerDiscreteScheduler
+        from refiners.foundationals.latent_diffusion.solvers import FrankenSolver
+
+        scheduler = EulerDiscreteScheduler(...)
+        solver = FrankenSolver(scheduler, num_inference_steps=steps)
+    """
+
+    default_params = dataclasses.replace(
+        Solver.default_params,
+        timesteps_spacing=TimestepSpacing.CUSTOM,
+    )
+
+    def __init__(
+        self,
+        diffusers_scheduler: Any,
+        num_inference_steps: int,
+        first_inference_step: int = 0,
+        **kwargs: Any,
+    ) -> None:
+        self.diffusers_scheduler = diffusers_scheduler
+        diffusers_scheduler.set_timesteps(num_inference_steps)
+        super().__init__(num_inference_steps=num_inference_steps, first_inference_step=first_inference_step)
+
+    def _generate_timesteps(self) -> Tensor:
+        return self.diffusers_scheduler.timesteps
+
+    def rebuild(
+        self,
+        num_inference_steps: int | None,
+        first_inference_step: int | None = None,
+    ) -> "FrankenSolver":
+        return self.__class__(
+            diffusers_scheduler=self.diffusers_scheduler,
+            num_inference_steps=self.num_inference_steps if num_inference_steps is None else num_inference_steps,
+            first_inference_step=self.first_inference_step if first_inference_step is None else first_inference_step,
+        )
+
+    def scale_model_input(self, x: Tensor, step: int) -> Tensor:
+        if step == -1:
+            return x * self.diffusers_scheduler.init_noise_sigma
+        return self.diffusers_scheduler.scale_model_input(x, self.timesteps[step])
+
+    def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
+        timestep = self.timesteps[step]
+        return cast(Tensor, self.diffusers_scheduler.step(predicted_noise, timestep, x).prev_sample)
diff --git a/tests/foundationals/latent_diffusion/test_solvers.py b/tests/foundationals/latent_diffusion/test_solvers.py
index c653c48..dfc2dcd 100644
--- a/tests/foundationals/latent_diffusion/test_solvers.py
+++ b/tests/foundationals/latent_diffusion/test_solvers.py
@@ -10,6 +10,7 @@ from refiners.foundationals.latent_diffusion.solvers import (
     DDPM,
     DPMSolver,
     Euler,
+    FrankenSolver,
     LCMSolver,
     ModelPredictionType,
     NoiseSchedule,
@@ -119,6 +120,42 @@ def test_euler_diffusers(model_prediction_type: ModelPredictionType):
     assert allclose(diffusers_output, refiners_output, rtol=0.02), f"outputs differ at step {step}"
 
 
+def test_franken_diffusers():
+    from diffusers import EulerDiscreteScheduler  # type: ignore
+
+    manual_seed(0)
+    params = {
+        "beta_end": 0.012,
+        "beta_schedule": "scaled_linear",
+        "beta_start": 0.00085,
+        "num_train_timesteps": 1000,
+        "steps_offset": 1,
+        "timestep_spacing": "linspace",
+        "use_karras_sigmas": False,
+    }
+
+    diffusers_scheduler = EulerDiscreteScheduler(**params)  # type: ignore
+    diffusers_scheduler.set_timesteps(30)
+
+    diffusers_scheduler_2 = EulerDiscreteScheduler(**params)  # type: ignore
+    refiners_scheduler = FrankenSolver(diffusers_scheduler_2, num_inference_steps=30)
+    assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
+
+    sample = randn(1, 4, 32, 32)
+    predicted_noise = randn(1, 4, 32, 32)
+
+    ref_init_noise_sigma = diffusers_scheduler.init_noise_sigma  # type: ignore
+    assert isinstance(ref_init_noise_sigma, Tensor)
+    init_noise_sigma = refiners_scheduler.scale_model_input(tensor(1), step=-1)
+    assert equal(ref_init_noise_sigma, init_noise_sigma), "init_noise_sigma differ"
+
+    for step, timestep in enumerate(diffusers_scheduler.timesteps):
+        diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample)  # type: ignore
+        refiners_output = refiners_scheduler(x=sample, predicted_noise=predicted_noise, step=step)
+
+        assert equal(diffusers_output, refiners_output), f"outputs differ at step {step}"
+
+
 def test_lcm_diffusers():
     from diffusers import LCMScheduler  # type: ignore
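
Usage note (not part of the patch): a minimal sketch of driving the new FrankenSolver by hand, using only calls that appear in franken.py and the test above. It assumes diffusers is installed and this patch is applied; the scheduler parameters, tensor shapes, and the driver loop are illustrative assumptions, not a prescribed pipeline.

# Sketch only: assumes `diffusers` is installed and this patch is applied.
# Parameter values and tensor shapes below are illustrative.
import torch
from diffusers import EulerDiscreteScheduler

from refiners.foundationals.latent_diffusion.solvers import FrankenSolver

# Wrap a Diffusers scheduler so it can be used wherever a Refiners Solver is expected.
scheduler = EulerDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
solver = FrankenSolver(scheduler, num_inference_steps=30)

x = torch.randn(1, 4, 32, 32)                # stand-in for initial latents
predicted_noise = torch.randn(1, 4, 32, 32)  # stand-in for a UNet noise prediction

# step=-1 multiplies by the scheduler's init_noise_sigma;
# any other step defers to the scheduler's scale_model_input.
x = solver.scale_model_input(x, step=-1)

for step in range(solver.num_inference_steps):
    scaled = solver.scale_model_input(x, step=step)
    # ... run the UNet on `scaled` to obtain predicted_noise ...
    x = solver(x, predicted_noise, step=step)  # delegates to scheduler.step(...).prev_sample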