mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-09 23:12:02 +00:00)
improve FrankenSolver
It now takes a Scheduler factory instead of a Scheduler. This lets the user potentially recreate the Scheduler on `rebuild`. It also properly sets the device and dtype on rebuild, and it has better typing.
parent 299217f45a
commit daee77298d
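
For context, a rough usage sketch of the new factory-based API described above; the scheduler configuration and step counts are illustrative assumptions, not taken from this commit:

# Sketch only: scheduler settings and step counts are assumptions for illustration.
from diffusers import EulerDiscreteScheduler  # type: ignore

from refiners.foundationals.latent_diffusion.solvers import FrankenSolver

def make_scheduler() -> EulerDiscreteScheduler:
    # Zero-argument factory: FrankenSolver calls it in __init__ and again on rebuild,
    # so a factory that constructs a new scheduler yields a fresh, unmutated instance each time.
    return EulerDiscreteScheduler(num_train_timesteps=1000)

solver = FrankenSolver(make_scheduler, num_inference_steps=30)

# rebuild re-invokes the factory and now also carries the solver's device and dtype over.
solver_20 = solver.rebuild(num_inference_steps=20)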
@@ -1,10 +1,42 @@
 import dataclasses
-from typing import Any, cast
+from typing import Any, Callable, Protocol, TypeVar
 
-from torch import Generator, Tensor
+from torch import Generator, Tensor, device as Device, dtype as DType, float32
 
 from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing
 
+# Should be Tensor, but some Diffusers schedulers
+# are improperly typed as only accepting `int`.
+SchedulerTimestepT = Any
+
+
+class SchedulerOutputLike(Protocol):
+    @property
+    def prev_sample(self) -> Tensor: ...
+
+
+class SchedulerLike(Protocol):
+    timesteps: Tensor
+
+    @property
+    def init_noise_sigma(self) -> Tensor | float: ...
+
+    def set_timesteps(self, num_inference_steps: int, *args: Any, **kwargs: Any) -> None: ...
+
+    def scale_model_input(self, sample: Tensor, timestep: SchedulerTimestepT) -> Tensor: ...
+
+    def step(
+        self,
+        model_output: Tensor,
+        timestep: SchedulerTimestepT,
+        sample: Tensor,
+        *args: Any,
+        **kwargs: Any,
+    ) -> SchedulerOutputLike | tuple[Any]: ...
+
+
+TFrankenSolver = TypeVar("TFrankenSolver", bound="FrankenSolver")
+
 
 class FrankenSolver(Solver):
     """Lets you use Diffusers Schedulers as Refiners Solvers.

@@ -14,7 +46,7 @@ class FrankenSolver(Solver):
         from refiners.foundationals.latent_diffusion.solvers import FrankenSolver
 
         scheduler = EulerDiscreteScheduler(...)
-        solver = FrankenSolver(scheduler, num_inference_steps=steps)
+        solver = FrankenSolver(lambda: scheduler, num_inference_steps=steps)
     """
 
     default_params = dataclasses.replace(

@@ -24,27 +56,40 @@ class FrankenSolver(Solver):
 
     def __init__(
         self,
-        diffusers_scheduler: Any,
+        get_diffusers_scheduler: Callable[[], SchedulerLike],
         num_inference_steps: int,
         first_inference_step: int = 0,
-        **kwargs: Any,
+        device: Device | str = "cpu",
+        dtype: DType = float32,
+        **kwargs: Any,  # for typing, ignored
     ) -> None:
-        self.diffusers_scheduler = diffusers_scheduler
-        diffusers_scheduler.set_timesteps(num_inference_steps)
-        super().__init__(num_inference_steps=num_inference_steps, first_inference_step=first_inference_step)
+        self.get_diffusers_scheduler = get_diffusers_scheduler
+        self.diffusers_scheduler = self.get_diffusers_scheduler()
+        self.diffusers_scheduler.set_timesteps(num_inference_steps)
+        super().__init__(
+            num_inference_steps=num_inference_steps,
+            first_inference_step=first_inference_step,
+            device=device,
+            dtype=dtype,
+        )
 
     def _generate_timesteps(self) -> Tensor:
         return self.diffusers_scheduler.timesteps
 
+    def to(self: TFrankenSolver, device: Device | str | None = None, dtype: DType | None = None) -> TFrankenSolver:
+        return super().to(device=device, dtype=dtype)  # type: ignore
+
     def rebuild(
         self,
         num_inference_steps: int | None,
         first_inference_step: int | None = None,
     ) -> "FrankenSolver":
         return self.__class__(
-            diffusers_scheduler=self.diffusers_scheduler,
+            get_diffusers_scheduler=self.get_diffusers_scheduler,
             num_inference_steps=self.num_inference_steps if num_inference_steps is None else num_inference_steps,
             first_inference_step=self.first_inference_step if first_inference_step is None else first_inference_step,
+            device=self.device,
+            dtype=self.dtype,
         )
 
     def scale_model_input(self, x: Tensor, step: int) -> Tensor:

@@ -54,4 +99,6 @@ class FrankenSolver(Solver):
 
     def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
         timestep = self.timesteps[step]
-        return cast(Tensor, self.diffusers_scheduler.step(predicted_noise, timestep, x).prev_sample)
+        r = self.diffusers_scheduler.step(predicted_noise, timestep, x)
+        assert not isinstance(r, tuple), "scheduler returned a tuple"
+        return r.prev_sample
@@ -138,7 +138,7 @@ def test_franken_diffusers():
     diffusers_scheduler.set_timesteps(30)
 
     diffusers_scheduler_2 = EulerDiscreteScheduler(**params)  # type: ignore
-    solver = FrankenSolver(diffusers_scheduler_2, num_inference_steps=30)
+    solver = FrankenSolver(lambda: diffusers_scheduler_2, num_inference_steps=30)
     assert equal(solver.timesteps, diffusers_scheduler.timesteps)
 
     sample = randn(1, 4, 32, 32)
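
On the "better typing" part of the change: SchedulerLike and SchedulerOutputLike are structural Protocols, so the factory is not restricted to returning Diffusers classes. A minimal conforming object might look like the sketch below; the class is purely illustrative and is not part of this commit or of Diffusers:

import dataclasses
from typing import Any

from torch import Tensor, arange

@dataclasses.dataclass
class _StepOutput:
    # Structurally satisfies SchedulerOutputLike: exposes a readable `prev_sample`.
    prev_sample: Tensor

class TrivialScheduler:
    # Structurally satisfies SchedulerLike, with deliberately simplistic behaviour.
    timesteps: Tensor = arange(0)
    init_noise_sigma: float = 1.0

    def set_timesteps(self, num_inference_steps: int, *args: Any, **kwargs: Any) -> None:
        self.timesteps = arange(num_inference_steps).flip(0)

    def scale_model_input(self, sample: Tensor, timestep: Any) -> Tensor:
        return sample

    def step(self, model_output: Tensor, timestep: Any, sample: Tensor, *args: Any, **kwargs: Any) -> _StepOutput:
        return _StepOutput(prev_sample=sample - model_output)

# Accepted because the typing is structural: the class itself acts as a Callable[[], SchedulerLike].
# solver = FrankenSolver(TrivialScheduler, num_inference_steps=30)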