add missing constructor arguments to DDPM scheduler

This commit is contained in:
Pierre Chapuis 2024-01-23 09:27:26 +01:00
parent 40c33b9595
commit a5c665462a

View file

@ -1,6 +1,6 @@
from torch import Generator, Tensor, arange, device as Device
from torch import Generator, Tensor, arange, device as Device, dtype as DType
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
class DDPM(Scheduler):
@ -16,8 +16,10 @@ class DDPM(Scheduler):
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule | None = None, # ignored
first_inference_step: int = 0,
device: Device | str = "cpu",
dtype: DType | None = None, # ignored
) -> None:
super().__init__(
num_inference_steps=num_inference_steps,