fix ddpm and ddim __init__

This commit is contained in:
limiteinductive 2023-12-03 17:17:04 +01:00 committed by Benjamin Trom
parent ad8f02e555
commit 1075ea4a62
2 changed files with 18 additions and 9 deletions

View file

@ -14,10 +14,10 @@ class DDIM(Scheduler):
dtype: Dtype = float32, dtype: Dtype = float32,
) -> None: ) -> None:
super().__init__( super().__init__(
num_inference_steps, num_inference_steps=num_inference_steps,
num_train_timesteps, num_train_timesteps=num_train_timesteps,
initial_diffusion_rate, initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate, final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule, noise_schedule=noise_schedule,
device=device, device=device,
dtype=dtype, dtype=dtype,

View file

@ -16,7 +16,13 @@ class DDPM(Scheduler):
final_diffusion_rate: float = 1.2e-2, final_diffusion_rate: float = 1.2e-2,
device: Device | str = "cpu", device: Device | str = "cpu",
) -> None: ) -> None:
super().__init__(num_inference_steps, num_train_timesteps, initial_diffusion_rate, final_diffusion_rate, device) super().__init__(
num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
device=device,
)
def _generate_timesteps(self) -> Tensor: def _generate_timesteps(self) -> Tensor:
step_ratio = self.num_train_timesteps // self.num_inference_steps step_ratio = self.num_train_timesteps // self.num_inference_steps
@ -50,10 +56,13 @@ class DDPM(Scheduler):
else tensor(-(self.num_train_timesteps // self.num_inference_steps), device=self.device) else tensor(-(self.num_train_timesteps // self.num_inference_steps), device=self.device)
), ),
) )
current_cumulative_factor, previous_cumulative_scale_factor = (self.scale_factors.cumprod(0))[timestep], ( current_cumulative_factor, previous_cumulative_scale_factor = (
(self.scale_factors.cumprod(0))[previous_timestep] (self.scale_factors.cumprod(0))[timestep],
if step < len(self.timesteps) - 1 (
else tensor(1, device=self.device) (self.scale_factors.cumprod(0))[previous_timestep]
if step < len(self.timesteps) - 1
else tensor(1, device=self.device)
),
) )
current_factor = current_cumulative_factor / previous_cumulative_scale_factor current_factor = current_cumulative_factor / previous_cumulative_scale_factor
estimated_denoised_data = ( estimated_denoised_data = (