mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
fix ddpm and ddim __init__
This commit is contained in:
parent
ad8f02e555
commit
1075ea4a62
|
@ -14,10 +14,10 @@ class DDIM(Scheduler):
|
|||
dtype: Dtype = float32,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_inference_steps,
|
||||
num_train_timesteps,
|
||||
initial_diffusion_rate,
|
||||
final_diffusion_rate,
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
initial_diffusion_rate=initial_diffusion_rate,
|
||||
final_diffusion_rate=final_diffusion_rate,
|
||||
noise_schedule=noise_schedule,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
|
|
|
@ -16,7 +16,13 @@ class DDPM(Scheduler):
|
|||
final_diffusion_rate: float = 1.2e-2,
|
||||
device: Device | str = "cpu",
|
||||
) -> None:
|
||||
super().__init__(num_inference_steps, num_train_timesteps, initial_diffusion_rate, final_diffusion_rate, device)
|
||||
super().__init__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
initial_diffusion_rate=initial_diffusion_rate,
|
||||
final_diffusion_rate=final_diffusion_rate,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def _generate_timesteps(self) -> Tensor:
|
||||
step_ratio = self.num_train_timesteps // self.num_inference_steps
|
||||
|
@ -50,10 +56,13 @@ class DDPM(Scheduler):
|
|||
else tensor(-(self.num_train_timesteps // self.num_inference_steps), device=self.device)
|
||||
),
|
||||
)
|
||||
current_cumulative_factor, previous_cumulative_scale_factor = (self.scale_factors.cumprod(0))[timestep], (
|
||||
(self.scale_factors.cumprod(0))[previous_timestep]
|
||||
if step < len(self.timesteps) - 1
|
||||
else tensor(1, device=self.device)
|
||||
current_cumulative_factor, previous_cumulative_scale_factor = (
|
||||
(self.scale_factors.cumprod(0))[timestep],
|
||||
(
|
||||
(self.scale_factors.cumprod(0))[previous_timestep]
|
||||
if step < len(self.timesteps) - 1
|
||||
else tensor(1, device=self.device)
|
||||
),
|
||||
)
|
||||
current_factor = current_cumulative_factor / previous_cumulative_scale_factor
|
||||
estimated_denoised_data = (
|
||||
|
|
Loading…
Reference in a new issue