mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
add FrankenSolver
This solver is designed to use Diffusers Schedulers as Refiners Solvers.
This commit is contained in:
parent
e091788b88
commit
9e8c2a3753
|
@ -2,6 +2,7 @@ from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM
|
|||
from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM
|
||||
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
||||
from refiners.foundationals.latent_diffusion.solvers.euler import Euler
|
||||
from refiners.foundationals.latent_diffusion.solvers.franken import FrankenSolver
|
||||
from refiners.foundationals.latent_diffusion.solvers.lcm import LCMSolver
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||
ModelPredictionType,
|
||||
|
@ -18,6 +19,7 @@ __all__ = [
|
|||
"DDPM",
|
||||
"DDIM",
|
||||
"Euler",
|
||||
"FrankenSolver",
|
||||
"LCMSolver",
|
||||
"ModelPredictionType",
|
||||
"NoiseSchedule",
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
import dataclasses
|
||||
from typing import Any, cast
|
||||
|
||||
from torch import Generator, Tensor
|
||||
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing
|
||||
|
||||
|
||||
class FrankenSolver(Solver):
|
||||
"""Lets you use Diffusers Schedulers as Refiners Solvers.
|
||||
|
||||
For instance:
|
||||
from diffusers import EulerDiscreteScheduler
|
||||
from refiners.foundationals.latent_diffusion.solvers import FrankenSolver
|
||||
|
||||
scheduler = EulerDiscreteScheduler(...)
|
||||
solver = FrankenSolver(scheduler, num_inference_steps=steps)
|
||||
"""
|
||||
|
||||
default_params = dataclasses.replace(
|
||||
Solver.default_params,
|
||||
timesteps_spacing=TimestepSpacing.CUSTOM,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
diffusers_scheduler: Any,
|
||||
num_inference_steps: int,
|
||||
first_inference_step: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.diffusers_scheduler = diffusers_scheduler
|
||||
diffusers_scheduler.set_timesteps(num_inference_steps)
|
||||
super().__init__(num_inference_steps=num_inference_steps, first_inference_step=first_inference_step)
|
||||
|
||||
def _generate_timesteps(self) -> Tensor:
|
||||
return self.diffusers_scheduler.timesteps
|
||||
|
||||
def rebuild(
|
||||
self,
|
||||
num_inference_steps: int | None,
|
||||
first_inference_step: int | None = None,
|
||||
) -> "FrankenSolver":
|
||||
return self.__class__(
|
||||
diffusers_scheduler=self.diffusers_scheduler,
|
||||
num_inference_steps=self.num_inference_steps if num_inference_steps is None else num_inference_steps,
|
||||
first_inference_step=self.first_inference_step if first_inference_step is None else first_inference_step,
|
||||
)
|
||||
|
||||
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
|
||||
if step == -1:
|
||||
return x * self.diffusers_scheduler.init_noise_sigma
|
||||
return self.diffusers_scheduler.scale_model_input(x, self.timesteps[step])
|
||||
|
||||
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||
timestep = self.timesteps[step]
|
||||
return cast(Tensor, self.diffusers_scheduler.step(predicted_noise, timestep, x).prev_sample)
|
|
@ -10,6 +10,7 @@ from refiners.foundationals.latent_diffusion.solvers import (
|
|||
DDPM,
|
||||
DPMSolver,
|
||||
Euler,
|
||||
FrankenSolver,
|
||||
LCMSolver,
|
||||
ModelPredictionType,
|
||||
NoiseSchedule,
|
||||
|
@ -119,6 +120,42 @@ def test_euler_diffusers(model_prediction_type: ModelPredictionType):
|
|||
assert allclose(diffusers_output, refiners_output, rtol=0.02), f"outputs differ at step {step}"
|
||||
|
||||
|
||||
def test_franken_diffusers():
|
||||
from diffusers import EulerDiscreteScheduler # type: ignore
|
||||
|
||||
manual_seed(0)
|
||||
params = {
|
||||
"beta_end": 0.012,
|
||||
"beta_schedule": "scaled_linear",
|
||||
"beta_start": 0.00085,
|
||||
"num_train_timesteps": 1000,
|
||||
"steps_offset": 1,
|
||||
"timestep_spacing": "linspace",
|
||||
"use_karras_sigmas": False,
|
||||
}
|
||||
|
||||
diffusers_scheduler = EulerDiscreteScheduler(**params) # type: ignore
|
||||
diffusers_scheduler.set_timesteps(30)
|
||||
|
||||
diffusers_scheduler_2 = EulerDiscreteScheduler(**params) # type: ignore
|
||||
refiners_scheduler = FrankenSolver(diffusers_scheduler_2, num_inference_steps=30)
|
||||
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
|
||||
|
||||
sample = randn(1, 4, 32, 32)
|
||||
predicted_noise = randn(1, 4, 32, 32)
|
||||
|
||||
ref_init_noise_sigma = diffusers_scheduler.init_noise_sigma # type: ignore
|
||||
assert isinstance(ref_init_noise_sigma, Tensor)
|
||||
init_noise_sigma = refiners_scheduler.scale_model_input(tensor(1), step=-1)
|
||||
assert equal(ref_init_noise_sigma, init_noise_sigma), "init_noise_sigma differ"
|
||||
|
||||
for step, timestep in enumerate(diffusers_scheduler.timesteps):
|
||||
diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
|
||||
refiners_output = refiners_scheduler(x=sample, predicted_noise=predicted_noise, step=step)
|
||||
|
||||
assert equal(diffusers_output, refiners_output), f"outputs differ at step {step}"
|
||||
|
||||
|
||||
def test_lcm_diffusers():
|
||||
from diffusers import LCMScheduler # type: ignore
|
||||
|
||||
|
|
Loading…
Reference in a new issue