add FrankenSolver

This solver is designed to use Diffusers Schedulers as Refiners Solvers.
This commit is contained in:
Pierre Chapuis 2024-07-10 16:08:05 +02:00
parent e091788b88
commit 9e8c2a3753
3 changed files with 96 additions and 0 deletions

View file

@ -2,6 +2,7 @@ from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM
from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
from refiners.foundationals.latent_diffusion.solvers.euler import Euler from refiners.foundationals.latent_diffusion.solvers.euler import Euler
from refiners.foundationals.latent_diffusion.solvers.franken import FrankenSolver
from refiners.foundationals.latent_diffusion.solvers.lcm import LCMSolver from refiners.foundationals.latent_diffusion.solvers.lcm import LCMSolver
from refiners.foundationals.latent_diffusion.solvers.solver import ( from refiners.foundationals.latent_diffusion.solvers.solver import (
ModelPredictionType, ModelPredictionType,
@ -18,6 +19,7 @@ __all__ = [
"DDPM", "DDPM",
"DDIM", "DDIM",
"Euler", "Euler",
"FrankenSolver",
"LCMSolver", "LCMSolver",
"ModelPredictionType", "ModelPredictionType",
"NoiseSchedule", "NoiseSchedule",

View file

@ -0,0 +1,57 @@
import dataclasses
from typing import Any, cast
from torch import Generator, Tensor
from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing
class FrankenSolver(Solver):
"""Lets you use Diffusers Schedulers as Refiners Solvers.
For instance:
from diffusers import EulerDiscreteScheduler
from refiners.foundationals.latent_diffusion.solvers import FrankenSolver
scheduler = EulerDiscreteScheduler(...)
solver = FrankenSolver(scheduler, num_inference_steps=steps)
"""
default_params = dataclasses.replace(
Solver.default_params,
timesteps_spacing=TimestepSpacing.CUSTOM,
)
def __init__(
self,
diffusers_scheduler: Any,
num_inference_steps: int,
first_inference_step: int = 0,
**kwargs: Any,
) -> None:
self.diffusers_scheduler = diffusers_scheduler
diffusers_scheduler.set_timesteps(num_inference_steps)
super().__init__(num_inference_steps=num_inference_steps, first_inference_step=first_inference_step)
def _generate_timesteps(self) -> Tensor:
return self.diffusers_scheduler.timesteps
def rebuild(
self,
num_inference_steps: int | None,
first_inference_step: int | None = None,
) -> "FrankenSolver":
return self.__class__(
diffusers_scheduler=self.diffusers_scheduler,
num_inference_steps=self.num_inference_steps if num_inference_steps is None else num_inference_steps,
first_inference_step=self.first_inference_step if first_inference_step is None else first_inference_step,
)
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
if step == -1:
return x * self.diffusers_scheduler.init_noise_sigma
return self.diffusers_scheduler.scale_model_input(x, self.timesteps[step])
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
timestep = self.timesteps[step]
return cast(Tensor, self.diffusers_scheduler.step(predicted_noise, timestep, x).prev_sample)

View file

@ -10,6 +10,7 @@ from refiners.foundationals.latent_diffusion.solvers import (
DDPM, DDPM,
DPMSolver, DPMSolver,
Euler, Euler,
FrankenSolver,
LCMSolver, LCMSolver,
ModelPredictionType, ModelPredictionType,
NoiseSchedule, NoiseSchedule,
@ -119,6 +120,42 @@ def test_euler_diffusers(model_prediction_type: ModelPredictionType):
assert allclose(diffusers_output, refiners_output, rtol=0.02), f"outputs differ at step {step}" assert allclose(diffusers_output, refiners_output, rtol=0.02), f"outputs differ at step {step}"
def test_franken_diffusers():
from diffusers import EulerDiscreteScheduler # type: ignore
manual_seed(0)
params = {
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"num_train_timesteps": 1000,
"steps_offset": 1,
"timestep_spacing": "linspace",
"use_karras_sigmas": False,
}
diffusers_scheduler = EulerDiscreteScheduler(**params) # type: ignore
diffusers_scheduler.set_timesteps(30)
diffusers_scheduler_2 = EulerDiscreteScheduler(**params) # type: ignore
refiners_scheduler = FrankenSolver(diffusers_scheduler_2, num_inference_steps=30)
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
sample = randn(1, 4, 32, 32)
predicted_noise = randn(1, 4, 32, 32)
ref_init_noise_sigma = diffusers_scheduler.init_noise_sigma # type: ignore
assert isinstance(ref_init_noise_sigma, Tensor)
init_noise_sigma = refiners_scheduler.scale_model_input(tensor(1), step=-1)
assert equal(ref_init_noise_sigma, init_noise_sigma), "init_noise_sigma differ"
for step, timestep in enumerate(diffusers_scheduler.timesteps):
diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
refiners_output = refiners_scheduler(x=sample, predicted_noise=predicted_noise, step=step)
assert equal(diffusers_output, refiners_output), f"outputs differ at step {step}"
def test_lcm_diffusers(): def test_lcm_diffusers():
from diffusers import LCMScheduler # type: ignore from diffusers import LCMScheduler # type: ignore