diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py b/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py index afb6ff2..726bdcb 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py @@ -43,10 +43,13 @@ class DDIM(Scheduler): else tensor(data=[0], device=self.device, dtype=self.dtype) ), ) - current_scale_factor, previous_scale_factor = self.cumulative_scale_factors[timestep], ( - self.cumulative_scale_factors[previous_timestep] - if previous_timestep > 0 - else self.cumulative_scale_factors[0] + current_scale_factor, previous_scale_factor = ( + self.cumulative_scale_factors[timestep], + ( + self.cumulative_scale_factors[previous_timestep] + if previous_timestep > 0 + else self.cumulative_scale_factors[0] + ), ) predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor denoised_x = previous_scale_factor * predicted_x + sqrt(1 - previous_scale_factor**2) * noise diff --git a/tests/foundationals/latent_diffusion/test_schedulers.py b/tests/foundationals/latent_diffusion/test_schedulers.py index c8a1c98..5f1d9e0 100644 --- a/tests/foundationals/latent_diffusion/test_schedulers.py +++ b/tests/foundationals/latent_diffusion/test_schedulers.py @@ -112,9 +112,7 @@ def test_scheduler_remove_noise(): noise = randn(1, 4, 32, 32) for step, timestep in enumerate(diffusers_scheduler.timesteps): - diffusers_output = cast( - Tensor, diffusers_scheduler.step(noise, timestep, sample).pred_original_sample - ) # type: ignore + diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).pred_original_sample) # type: ignore refiners_output = refiners_scheduler.remove_noise(x=sample, noise=noise, step=step) assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"