mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
add missing constructor arguments to DDPM scheduler
This commit is contained in:
parent
40c33b9595
commit
a5c665462a
|
@ -1,6 +1,6 @@
|
||||||
from torch import Generator, Tensor, arange, device as Device
|
from torch import Generator, Tensor, arange, device as Device, dtype as DType
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
|
||||||
|
|
||||||
|
|
||||||
class DDPM(Scheduler):
|
class DDPM(Scheduler):
|
||||||
|
@ -16,8 +16,10 @@ class DDPM(Scheduler):
|
||||||
num_train_timesteps: int = 1_000,
|
num_train_timesteps: int = 1_000,
|
||||||
initial_diffusion_rate: float = 8.5e-4,
|
initial_diffusion_rate: float = 8.5e-4,
|
||||||
final_diffusion_rate: float = 1.2e-2,
|
final_diffusion_rate: float = 1.2e-2,
|
||||||
|
noise_schedule: NoiseSchedule | None = None, # ignored
|
||||||
first_inference_step: int = 0,
|
first_inference_step: int = 0,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
|
dtype: DType | None = None, # ignored
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
|
|
Loading…
Reference in a new issue