diff --git a/src/refiners/foundationals/latent_diffusion/solvers/franken.py b/src/refiners/foundationals/latent_diffusion/solvers/franken.py index 2842e67..b661c6c 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/franken.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/franken.py @@ -1,10 +1,42 @@ import dataclasses -from typing import Any, cast +from typing import Any, Callable, Protocol, TypeVar -from torch import Generator, Tensor +from torch import Generator, Tensor, device as Device, dtype as DType, float32 from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing +# Should be Tensor, but some Diffusers schedulers +# are improperly typed as only accepting `int`. +SchedulerTimestepT = Any + + +class SchedulerOutputLike(Protocol): + @property + def prev_sample(self) -> Tensor: ... + + +class SchedulerLike(Protocol): + timesteps: Tensor + + @property + def init_noise_sigma(self) -> Tensor | float: ... + + def set_timesteps(self, num_inference_steps: int, *args: Any, **kwargs: Any) -> None: ... + + def scale_model_input(self, sample: Tensor, timestep: SchedulerTimestepT) -> Tensor: ... + + def step( + self, + model_output: Tensor, + timestep: SchedulerTimestepT, + sample: Tensor, + *args: Any, + **kwargs: Any, + ) -> SchedulerOutputLike | tuple[Any]: ... + + +TFrankenSolver = TypeVar("TFrankenSolver", bound="FrankenSolver") + class FrankenSolver(Solver): """Lets you use Diffusers Schedulers as Refiners Solvers. @@ -14,7 +46,7 @@ class FrankenSolver(Solver): from refiners.foundationals.latent_diffusion.solvers import FrankenSolver scheduler = EulerDiscreteScheduler(...) - solver = FrankenSolver(scheduler, num_inference_steps=steps) + solver = FrankenSolver(lambda: scheduler, num_inference_steps=steps) """ default_params = dataclasses.replace( @@ -24,27 +56,40 @@ class FrankenSolver(Solver): def __init__( self, - diffusers_scheduler: Any, + get_diffusers_scheduler: Callable[[], SchedulerLike], num_inference_steps: int, first_inference_step: int = 0, - **kwargs: Any, + device: Device | str = "cpu", + dtype: DType = float32, + **kwargs: Any, # for typing, ignored ) -> None: - self.diffusers_scheduler = diffusers_scheduler - diffusers_scheduler.set_timesteps(num_inference_steps) - super().__init__(num_inference_steps=num_inference_steps, first_inference_step=first_inference_step) + self.get_diffusers_scheduler = get_diffusers_scheduler + self.diffusers_scheduler = self.get_diffusers_scheduler() + self.diffusers_scheduler.set_timesteps(num_inference_steps) + super().__init__( + num_inference_steps=num_inference_steps, + first_inference_step=first_inference_step, + device=device, + dtype=dtype, + ) def _generate_timesteps(self) -> Tensor: return self.diffusers_scheduler.timesteps + def to(self: TFrankenSolver, device: Device | str | None = None, dtype: DType | None = None) -> TFrankenSolver: + return super().to(device=device, dtype=dtype) # type: ignore + def rebuild( self, num_inference_steps: int | None, first_inference_step: int | None = None, ) -> "FrankenSolver": return self.__class__( - diffusers_scheduler=self.diffusers_scheduler, + get_diffusers_scheduler=self.get_diffusers_scheduler, num_inference_steps=self.num_inference_steps if num_inference_steps is None else num_inference_steps, first_inference_step=self.first_inference_step if first_inference_step is None else first_inference_step, + device=self.device, + dtype=self.dtype, ) def scale_model_input(self, x: Tensor, step: int) -> Tensor: @@ -54,4 +99,6 @@ class FrankenSolver(Solver): def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: timestep = self.timesteps[step] - return cast(Tensor, self.diffusers_scheduler.step(predicted_noise, timestep, x).prev_sample) + r = self.diffusers_scheduler.step(predicted_noise, timestep, x) + assert not isinstance(r, tuple), "scheduler returned a tuple" + return r.prev_sample diff --git a/tests/foundationals/latent_diffusion/test_solvers.py b/tests/foundationals/latent_diffusion/test_solvers.py index 540f319..f9888e1 100644 --- a/tests/foundationals/latent_diffusion/test_solvers.py +++ b/tests/foundationals/latent_diffusion/test_solvers.py @@ -138,7 +138,7 @@ def test_franken_diffusers(): diffusers_scheduler.set_timesteps(30) diffusers_scheduler_2 = EulerDiscreteScheduler(**params) # type: ignore - solver = FrankenSolver(diffusers_scheduler_2, num_inference_steps=30) + solver = FrankenSolver(lambda: diffusers_scheduler_2, num_inference_steps=30) assert equal(solver.timesteps, diffusers_scheduler.timesteps) sample = randn(1, 4, 32, 32)