mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
ruff format
This commit is contained in:
parent
8423c5efa7
commit
ad143b0867
|
@ -43,10 +43,13 @@ class DDIM(Scheduler):
|
||||||
else tensor(data=[0], device=self.device, dtype=self.dtype)
|
else tensor(data=[0], device=self.device, dtype=self.dtype)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
current_scale_factor, previous_scale_factor = self.cumulative_scale_factors[timestep], (
|
current_scale_factor, previous_scale_factor = (
|
||||||
|
self.cumulative_scale_factors[timestep],
|
||||||
|
(
|
||||||
self.cumulative_scale_factors[previous_timestep]
|
self.cumulative_scale_factors[previous_timestep]
|
||||||
if previous_timestep > 0
|
if previous_timestep > 0
|
||||||
else self.cumulative_scale_factors[0]
|
else self.cumulative_scale_factors[0]
|
||||||
|
),
|
||||||
)
|
)
|
||||||
predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor
|
predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor
|
||||||
denoised_x = previous_scale_factor * predicted_x + sqrt(1 - previous_scale_factor**2) * noise
|
denoised_x = previous_scale_factor * predicted_x + sqrt(1 - previous_scale_factor**2) * noise
|
||||||
|
|
|
@ -112,9 +112,7 @@ def test_scheduler_remove_noise():
|
||||||
noise = randn(1, 4, 32, 32)
|
noise = randn(1, 4, 32, 32)
|
||||||
|
|
||||||
for step, timestep in enumerate(diffusers_scheduler.timesteps):
|
for step, timestep in enumerate(diffusers_scheduler.timesteps):
|
||||||
diffusers_output = cast(
|
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).pred_original_sample) # type: ignore
|
||||||
Tensor, diffusers_scheduler.step(noise, timestep, sample).pred_original_sample
|
|
||||||
) # type: ignore
|
|
||||||
refiners_output = refiners_scheduler.remove_noise(x=sample, noise=noise, step=step)
|
refiners_output = refiners_scheduler.remove_noise(x=sample, noise=noise, step=step)
|
||||||
|
|
||||||
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
||||||
|
|
Loading…
Reference in a new issue