improve FrankenSolver

It now takes a Scheduler factory instead of a Scheduler.
This lets the user potentially recreate the Scheduler on `rebuild`.

It also properly sets the device and dtype on rebuild,
and it has better typing.
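
For context, a minimal usage sketch of the factory-based API introduced here (assuming diffusers' EulerDiscreteScheduler with its default config; everything outside FrankenSolver's own signature is illustrative, not part of this commit):

    from diffusers import EulerDiscreteScheduler  # type: ignore
    from refiners.foundationals.latent_diffusion.solvers import FrankenSolver

    # The argument is a factory: it is called once at construction time,
    # and called again whenever the solver is recreated via rebuild().
    solver = FrankenSolver(
        lambda: EulerDiscreteScheduler(),
        num_inference_steps=30,
    )

    # rebuild() reuses the same factory and now forwards device and dtype.
    solver_20 = solver.rebuild(num_inference_steps=20)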
Pierre Chapuis 2024-07-17 19:04:02 +02:00
parent 299217f45a
commit daee77298d
2 changed files with 58 additions and 11 deletions


@@ -1,10 +1,42 @@
 import dataclasses
-from typing import Any, cast
+from typing import Any, Callable, Protocol, TypeVar
 
-from torch import Generator, Tensor
+from torch import Generator, Tensor, device as Device, dtype as DType, float32
 
 from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing
 
+# Should be Tensor, but some Diffusers schedulers
+# are improperly typed as only accepting `int`.
+SchedulerTimestepT = Any
+
+
+class SchedulerOutputLike(Protocol):
+    @property
+    def prev_sample(self) -> Tensor: ...
+
+
+class SchedulerLike(Protocol):
+    timesteps: Tensor
+
+    @property
+    def init_noise_sigma(self) -> Tensor | float: ...
+
+    def set_timesteps(self, num_inference_steps: int, *args: Any, **kwargs: Any) -> None: ...
+
+    def scale_model_input(self, sample: Tensor, timestep: SchedulerTimestepT) -> Tensor: ...
+
+    def step(
+        self,
+        model_output: Tensor,
+        timestep: SchedulerTimestepT,
+        sample: Tensor,
+        *args: Any,
+        **kwargs: Any,
+    ) -> SchedulerOutputLike | tuple[Any]: ...
+
+
+TFrankenSolver = TypeVar("TFrankenSolver", bound="FrankenSolver")
+
 
 class FrankenSolver(Solver):
     """Lets you use Diffusers Schedulers as Refiners Solvers.
@@ -14,7 +46,7 @@ class FrankenSolver(Solver):
         from refiners.foundationals.latent_diffusion.solvers import FrankenSolver
 
         scheduler = EulerDiscreteScheduler(...)
-        solver = FrankenSolver(scheduler, num_inference_steps=steps)
+        solver = FrankenSolver(lambda: scheduler, num_inference_steps=steps)
     """
 
     default_params = dataclasses.replace(
@@ -24,27 +56,40 @@ class FrankenSolver(Solver):
 
     def __init__(
         self,
-        diffusers_scheduler: Any,
+        get_diffusers_scheduler: Callable[[], SchedulerLike],
         num_inference_steps: int,
         first_inference_step: int = 0,
-        **kwargs: Any,
+        device: Device | str = "cpu",
+        dtype: DType = float32,
+        **kwargs: Any,  # for typing, ignored
     ) -> None:
-        self.diffusers_scheduler = diffusers_scheduler
-        diffusers_scheduler.set_timesteps(num_inference_steps)
-        super().__init__(num_inference_steps=num_inference_steps, first_inference_step=first_inference_step)
+        self.get_diffusers_scheduler = get_diffusers_scheduler
+        self.diffusers_scheduler = self.get_diffusers_scheduler()
+        self.diffusers_scheduler.set_timesteps(num_inference_steps)
+        super().__init__(
+            num_inference_steps=num_inference_steps,
+            first_inference_step=first_inference_step,
+            device=device,
+            dtype=dtype,
+        )
 
     def _generate_timesteps(self) -> Tensor:
         return self.diffusers_scheduler.timesteps
 
+    def to(self: TFrankenSolver, device: Device | str | None = None, dtype: DType | None = None) -> TFrankenSolver:
+        return super().to(device=device, dtype=dtype)  # type: ignore
+
     def rebuild(
         self,
         num_inference_steps: int | None,
         first_inference_step: int | None = None,
     ) -> "FrankenSolver":
         return self.__class__(
-            diffusers_scheduler=self.diffusers_scheduler,
+            get_diffusers_scheduler=self.get_diffusers_scheduler,
             num_inference_steps=self.num_inference_steps if num_inference_steps is None else num_inference_steps,
             first_inference_step=self.first_inference_step if first_inference_step is None else first_inference_step,
+            device=self.device,
+            dtype=self.dtype,
         )
 
     def scale_model_input(self, x: Tensor, step: int) -> Tensor:
@@ -54,4 +99,6 @@ class FrankenSolver(Solver):
 
     def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
         timestep = self.timesteps[step]
-        return cast(Tensor, self.diffusers_scheduler.step(predicted_noise, timestep, x).prev_sample)
+        r = self.diffusers_scheduler.step(predicted_noise, timestep, x)
+        assert not isinstance(r, tuple), "scheduler returned a tuple"
+        return r.prev_sample
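
The Protocol types above only describe the scheduler surface FrankenSolver actually touches; any Diffusers scheduler exposing these attributes satisfies them structurally. As a rough sketch, a denoising loop exercises that surface through the solver like this (unet and latents are illustrative names, not from this repository):

    # Hedged sketch of a denoising loop; each call delegates to the wrapped Diffusers scheduler.
    for step in range(solver.first_inference_step, solver.num_inference_steps):
        scaled = solver.scale_model_input(latents, step)        # -> scheduler.scale_model_input
        predicted_noise = unet(scaled, solver.timesteps[step])  # user-provided denoiser (illustrative)
        latents = solver(latents, predicted_noise, step)        # -> scheduler.step(...).prev_sample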


@@ -138,7 +138,7 @@ def test_franken_diffusers():
 
     diffusers_scheduler.set_timesteps(30)
     diffusers_scheduler_2 = EulerDiscreteScheduler(**params)  # type: ignore
-    solver = FrankenSolver(diffusers_scheduler_2, num_inference_steps=30)
+    solver = FrankenSolver(lambda: diffusers_scheduler_2, num_inference_steps=30)
     assert equal(solver.timesteps, diffusers_scheduler.timesteps)
 
     sample = randn(1, 4, 32, 32)