From 1075ea4a628b3cfaa96f8c5c5b539164461153ad Mon Sep 17 00:00:00 2001
From: limiteinductive
Date: Sun, 3 Dec 2023 17:17:04 +0100
Subject: [PATCH] fix ddpm and ddim __init__

---
 .../latent_diffusion/schedulers/ddim.py |  8 ++++----
 .../latent_diffusion/schedulers/ddpm.py | 19 ++++++++++++++-----
 2 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py b/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py
index 9c1bf7f..52c6340 100644
--- a/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py
+++ b/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py
@@ -14,10 +14,10 @@ class DDIM(Scheduler):
         dtype: Dtype = float32,
     ) -> None:
         super().__init__(
-            num_inference_steps,
-            num_train_timesteps,
-            initial_diffusion_rate,
-            final_diffusion_rate,
+            num_inference_steps=num_inference_steps,
+            num_train_timesteps=num_train_timesteps,
+            initial_diffusion_rate=initial_diffusion_rate,
+            final_diffusion_rate=final_diffusion_rate,
             noise_schedule=noise_schedule,
             device=device,
             dtype=dtype,
diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py b/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py
index 528d395..4bf5554 100644
--- a/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py
+++ b/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py
@@ -16,7 +16,13 @@ class DDPM(Scheduler):
         final_diffusion_rate: float = 1.2e-2,
         device: Device | str = "cpu",
     ) -> None:
-        super().__init__(num_inference_steps, num_train_timesteps, initial_diffusion_rate, final_diffusion_rate, device)
+        super().__init__(
+            num_inference_steps=num_inference_steps,
+            num_train_timesteps=num_train_timesteps,
+            initial_diffusion_rate=initial_diffusion_rate,
+            final_diffusion_rate=final_diffusion_rate,
+            device=device,
+        )
 
     def _generate_timesteps(self) -> Tensor:
         step_ratio = self.num_train_timesteps // self.num_inference_steps
@@ -50,10 +56,13 @@ class DDPM(Scheduler):
                 else tensor(-(self.num_train_timesteps // self.num_inference_steps), device=self.device)
             ),
         )
-        current_cumulative_factor, previous_cumulative_scale_factor = (self.scale_factors.cumprod(0))[timestep], (
-            (self.scale_factors.cumprod(0))[previous_timestep]
-            if step < len(self.timesteps) - 1
-            else tensor(1, device=self.device)
+        current_cumulative_factor, previous_cumulative_scale_factor = (
+            (self.scale_factors.cumprod(0))[timestep],
+            (
+                (self.scale_factors.cumprod(0))[previous_timestep]
+                if step < len(self.timesteps) - 1
+                else tensor(1, device=self.device)
+            ),
         )
         current_factor = current_cumulative_factor / previous_cumulative_scale_factor
         estimated_denoised_data = (