diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py b/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py index 2fb5952..dd1822b 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py @@ -1,6 +1,6 @@ -from torch import Generator, Tensor, arange, device as Device +from torch import Generator, Tensor, arange, device as Device, dtype as DType -from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler +from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler class DDPM(Scheduler): @@ -16,8 +16,10 @@ class DDPM(Scheduler): num_train_timesteps: int = 1_000, initial_diffusion_rate: float = 8.5e-4, final_diffusion_rate: float = 1.2e-2, + noise_schedule: NoiseSchedule | None = None, # ignored first_inference_step: int = 0, device: Device | str = "cpu", + dtype: DType | None = None, # ignored ) -> None: super().__init__( num_inference_steps=num_inference_steps,