mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 15:02:01 +00:00
improve FrankenSolver
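It now takes a Scheduler factory (a zero-argument callable that returns a Scheduler) instead of a Scheduler instance. This lets the user potentially recreate the Scheduler on `rebuild`, because `rebuild` passes the factory to the constructor, which calls it again. It also properly sets the device and dtype on rebuild, and it has better typing.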
This commit is contained in:
parent
299217f45a
commit
daee77298d
|
@ -1,10 +1,42 @@
|
|||
import dataclasses
|
||||
from typing import Any, cast
|
||||
from typing import Any, Callable, Protocol, TypeVar
|
||||
|
||||
from torch import Generator, Tensor
|
||||
from torch import Generator, Tensor, device as Device, dtype as DType, float32
|
||||
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing
|
||||
|
||||
# Should be Tensor, but some Diffusers schedulers
|
||||
# are improperly typed as only accepting `int`.
|
||||
SchedulerTimestepT = Any
|
||||
|
||||
|
||||
class SchedulerOutputLike(Protocol):
|
||||
@property
|
||||
def prev_sample(self) -> Tensor: ...
|
||||
|
||||
|
||||
class SchedulerLike(Protocol):
|
||||
timesteps: Tensor
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self) -> Tensor | float: ...
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, *args: Any, **kwargs: Any) -> None: ...
|
||||
|
||||
def scale_model_input(self, sample: Tensor, timestep: SchedulerTimestepT) -> Tensor: ...
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: Tensor,
|
||||
timestep: SchedulerTimestepT,
|
||||
sample: Tensor,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> SchedulerOutputLike | tuple[Any]: ...
|
||||
|
||||
|
||||
TFrankenSolver = TypeVar("TFrankenSolver", bound="FrankenSolver")
|
||||
|
||||
|
||||
class FrankenSolver(Solver):
|
||||
"""Lets you use Diffusers Schedulers as Refiners Solvers.
|
||||
|
@ -14,7 +46,7 @@ class FrankenSolver(Solver):
|
|||
from refiners.foundationals.latent_diffusion.solvers import FrankenSolver
|
||||
|
||||
scheduler = EulerDiscreteScheduler(...)
|
||||
solver = FrankenSolver(scheduler, num_inference_steps=steps)
|
||||
solver = FrankenSolver(lambda: scheduler, num_inference_steps=steps)
|
||||
"""
|
||||
|
||||
default_params = dataclasses.replace(
|
||||
|
@ -24,27 +56,40 @@ class FrankenSolver(Solver):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
diffusers_scheduler: Any,
|
||||
get_diffusers_scheduler: Callable[[], SchedulerLike],
|
||||
num_inference_steps: int,
|
||||
first_inference_step: int = 0,
|
||||
**kwargs: Any,
|
||||
device: Device | str = "cpu",
|
||||
dtype: DType = float32,
|
||||
**kwargs: Any, # for typing, ignored
|
||||
) -> None:
|
||||
self.diffusers_scheduler = diffusers_scheduler
|
||||
diffusers_scheduler.set_timesteps(num_inference_steps)
|
||||
super().__init__(num_inference_steps=num_inference_steps, first_inference_step=first_inference_step)
|
||||
self.get_diffusers_scheduler = get_diffusers_scheduler
|
||||
self.diffusers_scheduler = self.get_diffusers_scheduler()
|
||||
self.diffusers_scheduler.set_timesteps(num_inference_steps)
|
||||
super().__init__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
first_inference_step=first_inference_step,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def _generate_timesteps(self) -> Tensor:
|
||||
return self.diffusers_scheduler.timesteps
|
||||
|
||||
def to(self: TFrankenSolver, device: Device | str | None = None, dtype: DType | None = None) -> TFrankenSolver:
|
||||
return super().to(device=device, dtype=dtype) # type: ignore
|
||||
|
||||
def rebuild(
|
||||
self,
|
||||
num_inference_steps: int | None,
|
||||
first_inference_step: int | None = None,
|
||||
) -> "FrankenSolver":
|
||||
return self.__class__(
|
||||
diffusers_scheduler=self.diffusers_scheduler,
|
||||
get_diffusers_scheduler=self.get_diffusers_scheduler,
|
||||
num_inference_steps=self.num_inference_steps if num_inference_steps is None else num_inference_steps,
|
||||
first_inference_step=self.first_inference_step if first_inference_step is None else first_inference_step,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
|
||||
|
@ -54,4 +99,6 @@ class FrankenSolver(Solver):
|
|||
|
||||
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||
timestep = self.timesteps[step]
|
||||
return cast(Tensor, self.diffusers_scheduler.step(predicted_noise, timestep, x).prev_sample)
|
||||
r = self.diffusers_scheduler.step(predicted_noise, timestep, x)
|
||||
assert not isinstance(r, tuple), "scheduler returned a tuple"
|
||||
return r.prev_sample
|
||||
|
|
|
@ -138,7 +138,7 @@ def test_franken_diffusers():
|
|||
diffusers_scheduler.set_timesteps(30)
|
||||
|
||||
diffusers_scheduler_2 = EulerDiscreteScheduler(**params) # type: ignore
|
||||
solver = FrankenSolver(diffusers_scheduler_2, num_inference_steps=30)
|
||||
solver = FrankenSolver(lambda: diffusers_scheduler_2, num_inference_steps=30)
|
||||
assert equal(solver.timesteps, diffusers_scheduler.timesteps)
|
||||
|
||||
sample = randn(1, 4, 32, 32)
|
||||
|
|
Loading…
Reference in a new issue