From a5c665462a8c5d942e969d2ecda15fb170735ef8 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Tue, 23 Jan 2024 09:27:26 +0100 Subject: [PATCH] add missing constructor arguments to DDPM scheduler --- .../foundationals/latent_diffusion/schedulers/ddpm.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py b/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py index 2fb5952..dd1822b 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py @@ -1,6 +1,6 @@ -from torch import Generator, Tensor, arange, device as Device +from torch import Generator, Tensor, arange, device as Device, dtype as DType -from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler +from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler class DDPM(Scheduler): @@ -16,8 +16,10 @@ class DDPM(Scheduler): num_train_timesteps: int = 1_000, initial_diffusion_rate: float = 8.5e-4, final_diffusion_rate: float = 1.2e-2, + noise_schedule: NoiseSchedule | None = None, # ignored first_inference_step: int = 0, device: Device | str = "cpu", + dtype: DType | None = None, # ignored ) -> None: super().__init__( num_inference_steps=num_inference_steps,