From 2b2b6740b7e780734398603c6a678c8fb065ac3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Deltheil?= Date: Wed, 10 Jan 2024 11:45:19 +0100 Subject: [PATCH] fix or silent pyright issues --- .../foundationals/latent_diffusion/schedulers/ddpm.py | 4 ++-- tests/foundationals/latent_diffusion/test_schedulers.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py b/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py index a7490dc..52ff5e9 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py @@ -1,4 +1,4 @@ -from torch import Tensor, arange, device as Device +from torch import Generator, Tensor, arange, device as Device from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler @@ -30,5 +30,5 @@ class DDPM(Scheduler): timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio return timesteps.flip(0) - def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor: + def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: raise NotImplementedError diff --git a/tests/foundationals/latent_diffusion/test_schedulers.py b/tests/foundationals/latent_diffusion/test_schedulers.py index 4b5841f..74ec637 100644 --- a/tests/foundationals/latent_diffusion/test_schedulers.py +++ b/tests/foundationals/latent_diffusion/test_schedulers.py @@ -64,7 +64,7 @@ def test_ddim_diffusers(): def test_euler_diffusers(): - from diffusers import EulerDiscreteScheduler + from diffusers import EulerDiscreteScheduler # type: ignore manual_seed(0) diffusers_scheduler = EulerDiscreteScheduler(