diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py b/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py index a7490dc..52ff5e9 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py @@ -1,4 +1,4 @@ -from torch import Tensor, arange, device as Device +from torch import Generator, Tensor, arange, device as Device from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler @@ -30,5 +30,5 @@ class DDPM(Scheduler): timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio return timesteps.flip(0) - def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor: + def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: raise NotImplementedError diff --git a/tests/foundationals/latent_diffusion/test_schedulers.py b/tests/foundationals/latent_diffusion/test_schedulers.py index 4b5841f..74ec637 100644 --- a/tests/foundationals/latent_diffusion/test_schedulers.py +++ b/tests/foundationals/latent_diffusion/test_schedulers.py @@ -64,7 +64,7 @@ def test_ddim_diffusers(): def test_euler_diffusers(): - from diffusers import EulerDiscreteScheduler + from diffusers import EulerDiscreteScheduler # type: ignore manual_seed(0) diffusers_scheduler = EulerDiscreteScheduler(