ruff format

This commit is contained in:
Cédric Deltheil 2024-01-10 11:41:47 +01:00 committed by Cédric Deltheil
parent 8423c5efa7
commit ad143b0867
2 changed files with 8 additions and 7 deletions

View file

@ -43,10 +43,13 @@ class DDIM(Scheduler):
else tensor(data=[0], device=self.device, dtype=self.dtype)
),
)
current_scale_factor, previous_scale_factor = self.cumulative_scale_factors[timestep], (
current_scale_factor, previous_scale_factor = (
self.cumulative_scale_factors[timestep],
(
self.cumulative_scale_factors[previous_timestep]
if previous_timestep > 0
else self.cumulative_scale_factors[0]
),
)
predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor
denoised_x = previous_scale_factor * predicted_x + sqrt(1 - previous_scale_factor**2) * noise

View file

@ -112,9 +112,7 @@ def test_scheduler_remove_noise():
noise = randn(1, 4, 32, 32)
for step, timestep in enumerate(diffusers_scheduler.timesteps):
diffusers_output = cast(
Tensor, diffusers_scheduler.step(noise, timestep, sample).pred_original_sample
) # type: ignore
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).pred_original_sample) # type: ignore
refiners_output = refiners_scheduler.remove_noise(x=sample, noise=noise, step=step)
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"