improve FrankenSolver

It now takes a Scheduler factory instead of a Scheduler.
This lets the user potentially recreate the Scheduler on `rebuild`.

It also properly sets the device and dtype on `rebuild`,
and it has better typing.
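
A minimal sketch of the new call convention (loosely based on the docstring and test; the scheduler choice and step counts are illustrative, and building the scheduler inside the lambda is just one way to get a fresh scheduler on rebuild):

from diffusers import EulerDiscreteScheduler
from torch import float32

from refiners.foundationals.latent_diffusion.solvers import FrankenSolver

# Pass a zero-argument factory instead of a scheduler instance. __init__ calls
# it to build the scheduler, and rebuild() hands the same factory to the new
# solver, which calls it again, so the scheduler can be recreated.
solver = FrankenSolver(
    get_diffusers_scheduler=lambda: EulerDiscreteScheduler(),
    num_inference_steps=30,
    device="cpu",
    dtype=float32,
)

# rebuild() keeps the factory, device and dtype of the original solver.
solver_50 = solver.rebuild(num_inference_steps=50)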
commit daee77298d (parent 299217f45a)
Author: Pierre Chapuis
Date: 2024-07-17 19:04:02 +02:00
2 changed files with 58 additions and 11 deletions


@@ -1,10 +1,42 @@
 import dataclasses
-from typing import Any, cast
+from typing import Any, Callable, Protocol, TypeVar
 
-from torch import Generator, Tensor
+from torch import Generator, Tensor, device as Device, dtype as DType, float32
 
 from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing
 
+# Should be Tensor, but some Diffusers schedulers
+# are improperly typed as only accepting `int`.
+SchedulerTimestepT = Any
+
+
+class SchedulerOutputLike(Protocol):
+    @property
+    def prev_sample(self) -> Tensor: ...
+
+
+class SchedulerLike(Protocol):
+    timesteps: Tensor
+
+    @property
+    def init_noise_sigma(self) -> Tensor | float: ...
+
+    def set_timesteps(self, num_inference_steps: int, *args: Any, **kwargs: Any) -> None: ...
+
+    def scale_model_input(self, sample: Tensor, timestep: SchedulerTimestepT) -> Tensor: ...
+
+    def step(
+        self,
+        model_output: Tensor,
+        timestep: SchedulerTimestepT,
+        sample: Tensor,
+        *args: Any,
+        **kwargs: Any,
+    ) -> SchedulerOutputLike | tuple[Any]: ...
+
+
+TFrankenSolver = TypeVar("TFrankenSolver", bound="FrankenSolver")
+
+
 class FrankenSolver(Solver):
     """Lets you use Diffusers Schedulers as Refiners Solvers.
@@ -14,7 +46,7 @@ class FrankenSolver(Solver):
         from refiners.foundationals.latent_diffusion.solvers import FrankenSolver
 
         scheduler = EulerDiscreteScheduler(...)
-        solver = FrankenSolver(scheduler, num_inference_steps=steps)
+        solver = FrankenSolver(lambda: scheduler, num_inference_steps=steps)
     """
 
     default_params = dataclasses.replace(
@@ -24,27 +56,40 @@ class FrankenSolver(Solver):
     def __init__(
         self,
-        diffusers_scheduler: Any,
+        get_diffusers_scheduler: Callable[[], SchedulerLike],
         num_inference_steps: int,
         first_inference_step: int = 0,
-        **kwargs: Any,
+        device: Device | str = "cpu",
+        dtype: DType = float32,
+        **kwargs: Any,  # for typing, ignored
     ) -> None:
-        self.diffusers_scheduler = diffusers_scheduler
-        diffusers_scheduler.set_timesteps(num_inference_steps)
-        super().__init__(num_inference_steps=num_inference_steps, first_inference_step=first_inference_step)
+        self.get_diffusers_scheduler = get_diffusers_scheduler
+        self.diffusers_scheduler = self.get_diffusers_scheduler()
+        self.diffusers_scheduler.set_timesteps(num_inference_steps)
+        super().__init__(
+            num_inference_steps=num_inference_steps,
+            first_inference_step=first_inference_step,
+            device=device,
+            dtype=dtype,
+        )
 
     def _generate_timesteps(self) -> Tensor:
         return self.diffusers_scheduler.timesteps
 
+    def to(self: TFrankenSolver, device: Device | str | None = None, dtype: DType | None = None) -> TFrankenSolver:
+        return super().to(device=device, dtype=dtype)  # type: ignore
+
     def rebuild(
         self,
         num_inference_steps: int | None,
         first_inference_step: int | None = None,
     ) -> "FrankenSolver":
         return self.__class__(
-            diffusers_scheduler=self.diffusers_scheduler,
+            get_diffusers_scheduler=self.get_diffusers_scheduler,
             num_inference_steps=self.num_inference_steps if num_inference_steps is None else num_inference_steps,
             first_inference_step=self.first_inference_step if first_inference_step is None else first_inference_step,
+            device=self.device,
+            dtype=self.dtype,
         )
 
     def scale_model_input(self, x: Tensor, step: int) -> Tensor:
@@ -54,4 +99,6 @@ class FrankenSolver(Solver):
     def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
         timestep = self.timesteps[step]
-        return cast(Tensor, self.diffusers_scheduler.step(predicted_noise, timestep, x).prev_sample)
+        r = self.diffusers_scheduler.step(predicted_noise, timestep, x)
+        assert not isinstance(r, tuple), "scheduler returned a tuple"
+        return r.prev_sample


@@ -138,7 +138,7 @@ def test_franken_diffusers():
     diffusers_scheduler.set_timesteps(30)
 
     diffusers_scheduler_2 = EulerDiscreteScheduler(**params)  # type: ignore
-    solver = FrankenSolver(diffusers_scheduler_2, num_inference_steps=30)
+    solver = FrankenSolver(lambda: diffusers_scheduler_2, num_inference_steps=30)
     assert equal(solver.timesteps, diffusers_scheduler.timesteps)
 
     sample = randn(1, 4, 32, 32)