mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
add FrankenSolver
This solver is designed to use Diffusers Schedulers as Refiners Solvers.
This commit is contained in:
parent
e091788b88
commit
9e8c2a3753
|
@ -2,6 +2,7 @@ from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM
|
||||||
from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM
|
from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM
|
||||||
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
||||||
from refiners.foundationals.latent_diffusion.solvers.euler import Euler
|
from refiners.foundationals.latent_diffusion.solvers.euler import Euler
|
||||||
|
from refiners.foundationals.latent_diffusion.solvers.franken import FrankenSolver
|
||||||
from refiners.foundationals.latent_diffusion.solvers.lcm import LCMSolver
|
from refiners.foundationals.latent_diffusion.solvers.lcm import LCMSolver
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||||
ModelPredictionType,
|
ModelPredictionType,
|
||||||
|
@ -18,6 +19,7 @@ __all__ = [
|
||||||
"DDPM",
|
"DDPM",
|
||||||
"DDIM",
|
"DDIM",
|
||||||
"Euler",
|
"Euler",
|
||||||
|
"FrankenSolver",
|
||||||
"LCMSolver",
|
"LCMSolver",
|
||||||
"ModelPredictionType",
|
"ModelPredictionType",
|
||||||
"NoiseSchedule",
|
"NoiseSchedule",
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
import dataclasses
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from torch import Generator, Tensor
|
||||||
|
|
||||||
|
from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing
|
||||||
|
|
||||||
|
|
||||||
|
class FrankenSolver(Solver):
|
||||||
|
"""Lets you use Diffusers Schedulers as Refiners Solvers.
|
||||||
|
|
||||||
|
For instance:
|
||||||
|
from diffusers import EulerDiscreteScheduler
|
||||||
|
from refiners.foundationals.latent_diffusion.solvers import FrankenSolver
|
||||||
|
|
||||||
|
scheduler = EulerDiscreteScheduler(...)
|
||||||
|
solver = FrankenSolver(scheduler, num_inference_steps=steps)
|
||||||
|
"""
|
||||||
|
|
||||||
|
default_params = dataclasses.replace(
|
||||||
|
Solver.default_params,
|
||||||
|
timesteps_spacing=TimestepSpacing.CUSTOM,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
diffusers_scheduler: Any,
|
||||||
|
num_inference_steps: int,
|
||||||
|
first_inference_step: int = 0,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
self.diffusers_scheduler = diffusers_scheduler
|
||||||
|
diffusers_scheduler.set_timesteps(num_inference_steps)
|
||||||
|
super().__init__(num_inference_steps=num_inference_steps, first_inference_step=first_inference_step)
|
||||||
|
|
||||||
|
def _generate_timesteps(self) -> Tensor:
|
||||||
|
return self.diffusers_scheduler.timesteps
|
||||||
|
|
||||||
|
def rebuild(
|
||||||
|
self,
|
||||||
|
num_inference_steps: int | None,
|
||||||
|
first_inference_step: int | None = None,
|
||||||
|
) -> "FrankenSolver":
|
||||||
|
return self.__class__(
|
||||||
|
diffusers_scheduler=self.diffusers_scheduler,
|
||||||
|
num_inference_steps=self.num_inference_steps if num_inference_steps is None else num_inference_steps,
|
||||||
|
first_inference_step=self.first_inference_step if first_inference_step is None else first_inference_step,
|
||||||
|
)
|
||||||
|
|
||||||
|
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
|
||||||
|
if step == -1:
|
||||||
|
return x * self.diffusers_scheduler.init_noise_sigma
|
||||||
|
return self.diffusers_scheduler.scale_model_input(x, self.timesteps[step])
|
||||||
|
|
||||||
|
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||||
|
timestep = self.timesteps[step]
|
||||||
|
return cast(Tensor, self.diffusers_scheduler.step(predicted_noise, timestep, x).prev_sample)
|
|
@ -10,6 +10,7 @@ from refiners.foundationals.latent_diffusion.solvers import (
|
||||||
DDPM,
|
DDPM,
|
||||||
DPMSolver,
|
DPMSolver,
|
||||||
Euler,
|
Euler,
|
||||||
|
FrankenSolver,
|
||||||
LCMSolver,
|
LCMSolver,
|
||||||
ModelPredictionType,
|
ModelPredictionType,
|
||||||
NoiseSchedule,
|
NoiseSchedule,
|
||||||
|
@ -119,6 +120,42 @@ def test_euler_diffusers(model_prediction_type: ModelPredictionType):
|
||||||
assert allclose(diffusers_output, refiners_output, rtol=0.02), f"outputs differ at step {step}"
|
assert allclose(diffusers_output, refiners_output, rtol=0.02), f"outputs differ at step {step}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_franken_diffusers():
|
||||||
|
from diffusers import EulerDiscreteScheduler # type: ignore
|
||||||
|
|
||||||
|
manual_seed(0)
|
||||||
|
params = {
|
||||||
|
"beta_end": 0.012,
|
||||||
|
"beta_schedule": "scaled_linear",
|
||||||
|
"beta_start": 0.00085,
|
||||||
|
"num_train_timesteps": 1000,
|
||||||
|
"steps_offset": 1,
|
||||||
|
"timestep_spacing": "linspace",
|
||||||
|
"use_karras_sigmas": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
diffusers_scheduler = EulerDiscreteScheduler(**params) # type: ignore
|
||||||
|
diffusers_scheduler.set_timesteps(30)
|
||||||
|
|
||||||
|
diffusers_scheduler_2 = EulerDiscreteScheduler(**params) # type: ignore
|
||||||
|
refiners_scheduler = FrankenSolver(diffusers_scheduler_2, num_inference_steps=30)
|
||||||
|
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
|
||||||
|
|
||||||
|
sample = randn(1, 4, 32, 32)
|
||||||
|
predicted_noise = randn(1, 4, 32, 32)
|
||||||
|
|
||||||
|
ref_init_noise_sigma = diffusers_scheduler.init_noise_sigma # type: ignore
|
||||||
|
assert isinstance(ref_init_noise_sigma, Tensor)
|
||||||
|
init_noise_sigma = refiners_scheduler.scale_model_input(tensor(1), step=-1)
|
||||||
|
assert equal(ref_init_noise_sigma, init_noise_sigma), "init_noise_sigma differ"
|
||||||
|
|
||||||
|
for step, timestep in enumerate(diffusers_scheduler.timesteps):
|
||||||
|
diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
|
||||||
|
refiners_output = refiners_scheduler(x=sample, predicted_noise=predicted_noise, step=step)
|
||||||
|
|
||||||
|
assert equal(diffusers_output, refiners_output), f"outputs differ at step {step}"
|
||||||
|
|
||||||
|
|
||||||
def test_lcm_diffusers():
|
def test_lcm_diffusers():
|
||||||
from diffusers import LCMScheduler # type: ignore
|
from diffusers import LCMScheduler # type: ignore
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue