mirror of https://github.com/finegrain-ai/refiners.git
fix ddpm and ddim __init__

commit 1075ea4a62 (parent ad8f02e555)
@@ -14,10 +14,10 @@ class DDIM(Scheduler):
         dtype: Dtype = float32,
     ) -> None:
         super().__init__(
-            num_inference_steps,
-            num_train_timesteps,
-            initial_diffusion_rate,
-            final_diffusion_rate,
+            num_inference_steps=num_inference_steps,
+            num_train_timesteps=num_train_timesteps,
+            initial_diffusion_rate=initial_diffusion_rate,
+            final_diffusion_rate=final_diffusion_rate,
             noise_schedule=noise_schedule,
             device=device,
             dtype=dtype,
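Note on the DDIM hunk: the four arguments were already bound to the right parameters, so this change is about robustness; keyword binding keeps the super().__init__ call correct even if the base Scheduler signature is ever reordered or grows a new defaulted parameter. A minimal sketch of that property, with hypothetical names rather than the repo's real classes:

    # Minimal sketch with hypothetical names (not the repo's real Scheduler).
    class Base:
        def __init__(self, a: int, b: int = 10) -> None:
            self.a, self.b = a, b

    class Child(Base):
        def __init__(self, a: int, b: int = 10) -> None:
            # Keyword binding stays correct even if Base later reorders (a, b)
            # or inserts a new defaulted parameter between them.
            super().__init__(a=a, b=b)

    assert (Child(1, 2).a, Child(1, 2).b) == (1, 2)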
@@ -16,7 +16,13 @@ class DDPM(Scheduler):
         final_diffusion_rate: float = 1.2e-2,
         device: Device | str = "cpu",
     ) -> None:
-        super().__init__(num_inference_steps, num_train_timesteps, initial_diffusion_rate, final_diffusion_rate, device)
+        super().__init__(
+            num_inference_steps=num_inference_steps,
+            num_train_timesteps=num_train_timesteps,
+            initial_diffusion_rate=initial_diffusion_rate,
+            final_diffusion_rate=final_diffusion_rate,
+            device=device,
+        )
 
     def _generate_timesteps(self) -> Tensor:
         step_ratio = self.num_train_timesteps // self.num_inference_steps
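Note on the DDPM hunk: beyond style, the old one-liner looks like an actual bug fix. It passed device as the fifth positional argument, and the DDIM hunk above suggests the base Scheduler takes noise_schedule in that position, so device would have been bound to the wrong parameter. A toy reproduction of that misbinding, assuming such a parameter order (hypothetical flattened signature, not the repo's code):

    # Hypothetical flattening of the base signature, for illustration only.
    def scheduler_init(
        num_inference_steps: int,
        num_train_timesteps: int,
        initial_diffusion_rate: float,
        final_diffusion_rate: float,
        noise_schedule: str = "quadratic",
        device: str = "cpu",
    ) -> tuple[str, str]:
        return noise_schedule, device

    # Old positional call: "cuda" silently lands in the noise_schedule slot.
    assert scheduler_init(10, 1000, 8.5e-4, 1.2e-2, "cuda") == ("cuda", "cpu")

    # New keyword call: each value is bound to the parameter it names.
    assert scheduler_init(10, 1000, 8.5e-4, 1.2e-2, device="cuda") == ("quadratic", "cuda")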
@@ -50,10 +56,13 @@ class DDPM(Scheduler):
                 else tensor(-(self.num_train_timesteps // self.num_inference_steps), device=self.device)
             ),
         )
-        current_cumulative_factor, previous_cumulative_scale_factor = (self.scale_factors.cumprod(0))[timestep], (
-            (self.scale_factors.cumprod(0))[previous_timestep]
-            if step < len(self.timesteps) - 1
-            else tensor(1, device=self.device)
+        current_cumulative_factor, previous_cumulative_scale_factor = (
+            (self.scale_factors.cumprod(0))[timestep],
+            (
+                (self.scale_factors.cumprod(0))[previous_timestep]
+                if step < len(self.timesteps) - 1
+                else tensor(1, device=self.device)
+            ),
         )
         current_factor = current_cumulative_factor / previous_cumulative_scale_factor
         estimated_denoised_data = (
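Note on the second DDPM hunk: it only reflows the existing tuple assignment so the two targets visibly line up with their two values; the behavior is unchanged. For context, a sketch of what the cumulative factors compute, assuming scale_factors holds the per-step alphas of a DDPM-style schedule (values and names here are illustrative, not the repo's):

    import torch

    scale_factors = 1 - torch.linspace(8.5e-4, 1.2e-2, steps=1000)  # hypothetical alphas
    cumulative = scale_factors.cumprod(0)  # alpha_bar[t] = product of alphas up to t

    timestep, previous_timestep = 999, 979
    current_cumulative_factor = cumulative[timestep]
    # Loosely mirrors the conditional in the diff: past the last inference step
    # there is no previous timestep, so the previous factor falls back to 1.
    previous_cumulative_scale_factor = (
        cumulative[previous_timestep] if previous_timestep >= 0 else torch.tensor(1.0)
    )
    current_factor = current_cumulative_factor / previous_cumulative_scale_factor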