add karras sigmas to dpm solver

This commit is contained in:
limiteinductive 2024-09-06 10:56:24 +00:00 committed by Benjamin Trom
parent 5aef1408d8
commit af6c5aecbe
6 changed files with 215 additions and 88 deletions

View file

@ -1,18 +1,35 @@
import dataclasses import dataclasses
from collections import deque from collections import deque
from typing import NamedTuple
import numpy as np import numpy as np
import torch import torch
from torch import Generator, Tensor, device as Device, dtype as Dtype
from refiners.foundationals.latent_diffusion.solvers.solver import ( from refiners.foundationals.latent_diffusion.solvers.solver import (
BaseSolverParams, BaseSolverParams,
ModelPredictionType, ModelPredictionType,
NoiseSchedule,
Solver, Solver,
TimestepSpacing, TimestepSpacing,
) )
def safe_log(x: torch.Tensor, lower_bound: float = 1e-6) -> torch.Tensor:
"""Compute the log of a tensor with a lower bound."""
return torch.log(torch.maximum(x, torch.tensor(lower_bound)))
def safe_sqrt(x: torch.Tensor) -> torch.Tensor:
"""Compute the square root of a tensor ensuring that the input is non-negative"""
return torch.sqrt(torch.maximum(x, torch.tensor(0)))
class SolverTensors(NamedTuple):
cumulative_scale_factors: torch.Tensor
noise_std: torch.Tensor
signal_to_noise_ratios: torch.Tensor
class DPMSolver(Solver): class DPMSolver(Solver):
"""Diffusion probabilistic models (DPMs) solver. """Diffusion probabilistic models (DPMs) solver.
@ -37,9 +54,9 @@ class DPMSolver(Solver):
first_inference_step: int = 0, first_inference_step: int = 0,
params: BaseSolverParams | None = None, params: BaseSolverParams | None = None,
last_step_first_order: bool = False, last_step_first_order: bool = False,
device: Device | str = "cpu", device: torch.device | str = "cpu",
dtype: Dtype = torch.float32, dtype: torch.dtype = torch.float32,
): ) -> None:
"""Initializes a new DPM solver. """Initializes a new DPM solver.
Args: Args:
@ -64,6 +81,14 @@ class DPMSolver(Solver):
) )
self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2) self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2)
self.last_step_first_order = last_step_first_order self.last_step_first_order = last_step_first_order
sigmas = self.noise_std / self.cumulative_scale_factors
self.sigmas = self._rescale_sigmas(sigmas, self.params.sigma_schedule)
sigma_min = sigmas[0:1] # corresponds to `final_sigmas_type="sigma_min" in diffusers`
self.sigmas = torch.cat([self.sigmas, sigma_min])
self.cumulative_scale_factors, self.noise_std, self.signal_to_noise_ratios = self._solver_tensors_from_sigmas(
self.sigmas
)
self.timesteps = self._timesteps_from_sigmas(sigmas)
def rebuild( def rebuild(
self: "DPMSolver", self: "DPMSolver",
@ -83,7 +108,7 @@ class DPMSolver(Solver):
r.last_step_first_order = self.last_step_first_order r.last_step_first_order = self.last_step_first_order
return r return r
def _generate_timesteps(self) -> Tensor: def _generate_timesteps(self) -> torch.Tensor:
if self.params.timesteps_spacing != TimestepSpacing.CUSTOM: if self.params.timesteps_spacing != TimestepSpacing.CUSTOM:
return super()._generate_timesteps() return super()._generate_timesteps()
@ -96,9 +121,75 @@ class DPMSolver(Solver):
np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:] np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
return torch.tensor(np_space).flip(0) return torch.tensor(np_space).flip(0)
def _generate_sigmas(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Generate the sigmas used by the solver."""
assert self.params.sigma_schedule is not None, "sigma_schedule must be set for the DPM solver"
sigmas = self.noise_std / self.cumulative_scale_factors
sigmas = sigmas.flip(0)
rescaled_sigmas = self._rescale_sigmas(sigmas, self.params.sigma_schedule)
rescaled_sigmas = torch.cat([rescaled_sigmas, torch.tensor([0.0])])
return sigmas, rescaled_sigmas
def _rescale_sigmas(self, sigmas: torch.Tensor, sigma_schedule: NoiseSchedule | None) -> torch.Tensor:
"""Rescale the sigmas according to the sigma schedule."""
match sigma_schedule:
case NoiseSchedule.UNIFORM:
rho = 1
case NoiseSchedule.QUADRATIC:
rho = 2
case NoiseSchedule.KARRAS:
rho = 7
case None:
return torch.tensor(
np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()),
device=self.device,
)
linear_schedule = torch.linspace(0, 1, steps=self.num_inference_steps, device=self.device)
first_sigma = sigmas[0]
last_sigma = sigmas[-1]
rescaled_sigmas = (
first_sigma ** (1 / rho) + linear_schedule * (last_sigma ** (1 / rho) - first_sigma ** (1 / rho))
) ** rho
return rescaled_sigmas.flip(0)
def _timesteps_from_sigmas(self, sigmas: torch.Tensor) -> torch.Tensor:
"""Generate the timesteps from the sigmas."""
log_sigmas = safe_log(sigmas)
timesteps: list[torch.Tensor] = []
for sigma in self.sigmas[:-1]:
log_sigma = safe_log(sigma)
distance_matrix = log_sigma - log_sigmas.unsqueeze(1)
# Determine the range of sigma indices
low_indices = (distance_matrix >= 0).cumsum(dim=0).argmax(dim=0).clip(max=sigmas.size(0) - 2)
high_indices = low_indices + 1
low_log_sigma = log_sigmas[low_indices]
high_log_sigma = log_sigmas[high_indices]
# Interpolate sigma values
interpolation_weights = (low_log_sigma - log_sigma) / (low_log_sigma - high_log_sigma)
interpolation_weights = torch.clamp(interpolation_weights, 0, 1)
timestep = (1 - interpolation_weights) * low_indices + interpolation_weights * high_indices
timesteps.append(timestep)
return torch.cat(timesteps).round()
def _solver_tensors_from_sigmas(self, sigmas: torch.Tensor) -> SolverTensors:
"""Generate the tensors from the sigmas."""
cumulative_scale_factors = 1 / torch.sqrt(sigmas**2 + 1)
noise_std = sigmas * cumulative_scale_factors
signal_to_noise_ratios = safe_log(cumulative_scale_factors) - safe_log(noise_std)
return SolverTensors(
cumulative_scale_factors=cumulative_scale_factors,
noise_std=noise_std,
signal_to_noise_ratios=signal_to_noise_ratios,
)
def dpm_solver_first_order_update( def dpm_solver_first_order_update(
self, x: Tensor, noise: Tensor, step: int, sde_noise: Tensor | None = None self, x: torch.Tensor, noise: torch.Tensor, step: int, sde_noise: torch.Tensor | None = None
) -> Tensor: ) -> torch.Tensor:
"""Applies a first-order backward Euler update to the input data `x`. """Applies a first-order backward Euler update to the input data `x`.
Args: Args:
@ -109,32 +200,29 @@ class DPMSolver(Solver):
Returns: Returns:
The denoised version of the input data `x`. The denoised version of the input data `x`.
""" """
current_timestep = self.timesteps[step] current_ratio = self.signal_to_noise_ratios[step]
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0]) next_ratio = self.signal_to_noise_ratios[step + 1]
previous_ratio = self.signal_to_noise_ratios[previous_timestep] next_scale_factor = self.cumulative_scale_factors[step + 1]
current_ratio = self.signal_to_noise_ratios[current_timestep]
previous_scale_factor = self.cumulative_scale_factors[previous_timestep] next_noise_std = self.noise_std[step + 1]
current_noise_std = self.noise_std[step]
previous_noise_std = self.noise_std[previous_timestep] ratio_delta = current_ratio - next_ratio
current_noise_std = self.noise_std[current_timestep]
ratio_delta = current_ratio - previous_ratio
if sde_noise is None: if sde_noise is None:
return (previous_noise_std / current_noise_std) * x + ( return (next_noise_std / current_noise_std) * x + (1.0 - torch.exp(ratio_delta)) * next_scale_factor * noise
1.0 - torch.exp(ratio_delta)
) * previous_scale_factor * noise
factor = 1.0 - torch.exp(2.0 * ratio_delta) factor = 1.0 - torch.exp(2.0 * ratio_delta)
return ( return (
(previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x (next_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
+ previous_scale_factor * factor * noise + next_scale_factor * factor * noise
+ previous_noise_std * torch.sqrt(factor) * sde_noise + next_noise_std * safe_sqrt(factor) * sde_noise
) )
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int, sde_noise: Tensor | None = None) -> Tensor: def multistep_dpm_solver_second_order_update(
self, x: torch.Tensor, step: int, sde_noise: torch.Tensor | None = None
) -> torch.Tensor:
"""Applies a second-order backward Euler update to the input data `x`. """Applies a second-order backward Euler update to the input data `x`.
Args: Args:
@ -144,43 +232,41 @@ class DPMSolver(Solver):
Returns: Returns:
The denoised version of the input data `x`. The denoised version of the input data `x`.
""" """
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])
current_timestep = self.timesteps[step]
next_timestep = self.timesteps[step - 1]
current_data_estimation = self.estimated_data[-1] current_data_estimation = self.estimated_data[-1]
next_data_estimation = self.estimated_data[-2] previous_data_estimation = self.estimated_data[-2]
previous_ratio = self.signal_to_noise_ratios[previous_timestep] next_ratio = self.signal_to_noise_ratios[step + 1]
current_ratio = self.signal_to_noise_ratios[current_timestep] current_ratio = self.signal_to_noise_ratios[step]
next_ratio = self.signal_to_noise_ratios[next_timestep] previous_ratio = self.signal_to_noise_ratios[step - 1]
previous_scale_factor = self.cumulative_scale_factors[previous_timestep] next_scale_factor = self.cumulative_scale_factors[step + 1]
previous_noise_std = self.noise_std[previous_timestep] next_noise_std = self.noise_std[step + 1]
current_noise_std = self.noise_std[current_timestep] current_noise_std = self.noise_std[step]
estimation_delta = (current_data_estimation - next_data_estimation) / ( estimation_delta = (current_data_estimation - previous_data_estimation) / (
(current_ratio - next_ratio) / (previous_ratio - current_ratio) (current_ratio - previous_ratio) / (next_ratio - current_ratio)
) )
ratio_delta = current_ratio - previous_ratio ratio_delta = current_ratio - next_ratio
if sde_noise is None: if sde_noise is None:
factor = 1.0 - torch.exp(ratio_delta) factor = 1.0 - torch.exp(ratio_delta)
return ( return (
(previous_noise_std / current_noise_std) * x (next_noise_std / current_noise_std) * x
+ previous_scale_factor * factor * current_data_estimation + next_scale_factor * factor * current_data_estimation
+ 0.5 * previous_scale_factor * factor * estimation_delta + 0.5 * next_scale_factor * factor * estimation_delta
) )
factor = 1.0 - torch.exp(2.0 * ratio_delta) factor = 1.0 - torch.exp(2.0 * ratio_delta)
return ( return (
(previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x (next_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
+ previous_scale_factor * factor * current_data_estimation + next_scale_factor * factor * current_data_estimation
+ 0.5 * previous_scale_factor * factor * estimation_delta + 0.5 * next_scale_factor * factor * estimation_delta
+ previous_noise_std * torch.sqrt(factor) * sde_noise + next_noise_std * safe_sqrt(factor) * sde_noise
) )
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: def __call__(
self, x: torch.Tensor, predicted_noise: torch.Tensor, step: int, generator: torch.Generator | None = None
) -> torch.Tensor:
"""Apply one step of the backward diffusion process. """Apply one step of the backward diffusion process.
Note: Note:
@ -199,9 +285,8 @@ class DPMSolver(Solver):
""" """
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}" assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
current_timestep = self.timesteps[step] scale_factor = self.cumulative_scale_factors[step]
scale_factor = self.cumulative_scale_factors[current_timestep] noise_ratio = self.noise_std[step]
noise_ratio = self.noise_std[current_timestep]
estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor
self.estimated_data.append(estimated_denoised_data) self.estimated_data.append(estimated_denoised_data)
variance = self.params.sde_variance variance = self.params.sde_variance

View file

@ -67,6 +67,7 @@ class BaseSolverParams:
initial_diffusion_rate: float | None initial_diffusion_rate: float | None
final_diffusion_rate: float | None final_diffusion_rate: float | None
noise_schedule: NoiseSchedule | None noise_schedule: NoiseSchedule | None
sigma_schedule: NoiseSchedule | None
model_prediction_type: ModelPredictionType | None model_prediction_type: ModelPredictionType | None
sde_variance: float sde_variance: float
@ -91,6 +92,7 @@ class SolverParams(BaseSolverParams):
initial_diffusion_rate: float | None = None initial_diffusion_rate: float | None = None
final_diffusion_rate: float | None = None final_diffusion_rate: float | None = None
noise_schedule: NoiseSchedule | None = None noise_schedule: NoiseSchedule | None = None
sigma_schedule: NoiseSchedule | None = None
model_prediction_type: ModelPredictionType | None = None model_prediction_type: ModelPredictionType | None = None
sde_variance: float = 0.0 sde_variance: float = 0.0
@ -103,6 +105,7 @@ class ResolvedSolverParams(BaseSolverParams):
initial_diffusion_rate: float initial_diffusion_rate: float
final_diffusion_rate: float final_diffusion_rate: float
noise_schedule: NoiseSchedule noise_schedule: NoiseSchedule
sigma_schedule: NoiseSchedule | None
model_prediction_type: ModelPredictionType model_prediction_type: ModelPredictionType
sde_variance: float sde_variance: float
@ -140,6 +143,7 @@ class Solver(fl.Module, ABC):
initial_diffusion_rate=8.5e-4, initial_diffusion_rate=8.5e-4,
final_diffusion_rate=1.2e-2, final_diffusion_rate=1.2e-2,
noise_schedule=NoiseSchedule.QUADRATIC, noise_schedule=NoiseSchedule.QUADRATIC,
sigma_schedule=None,
model_prediction_type=ModelPredictionType.NOISE, model_prediction_type=ModelPredictionType.NOISE,
sde_variance=0.0, sde_variance=0.0,
) )
@ -404,14 +408,12 @@ class Solver(fl.Module, ABC):
A tensor representing the noise schedule. A tensor representing the noise schedule.
""" """
match self.params.noise_schedule: match self.params.noise_schedule:
case "uniform": case NoiseSchedule.UNIFORM:
return 1 - self.sample_power_distribution(1) return 1 - self.sample_power_distribution(1)
case "quadratic": case NoiseSchedule.QUADRATIC:
return 1 - self.sample_power_distribution(2) return 1 - self.sample_power_distribution(2)
case "karras": case NoiseSchedule.KARRAS:
return 1 - self.sample_power_distribution(7) return 1 - self.sample_power_distribution(7)
case _:
raise ValueError(f"Unknown noise schedule: {self.params.noise_schedule}")
def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver": def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver":
"""Move the solver to the specified device and data type. """Move the solver to the specified device and data type.

View file

@ -97,6 +97,11 @@ def expected_image_std_sde_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_sde_random_init.png").convert("RGB") return _img_open(ref_path / "expected_std_sde_random_init.png").convert("RGB")
@pytest.fixture
def expected_image_std_sde_karras_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_sde_karras_random_init.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image: def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB") return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB")
@ -913,6 +918,39 @@ def test_diffusion_std_sde_random_init(
ensure_similar_images(predicted_image, expected_image_std_sde_random_init) ensure_similar_images(predicted_image, expected_image_std_sde_random_init)
@no_grad()
def test_diffusion_std_sde_karras_random_init(
sd15_std_sde: StableDiffusion_1, expected_image_std_sde_karras_random_init: Image.Image, test_device: torch.device
):
sd15 = sd15_std_sde
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.solver = DPMSolver(
num_inference_steps=18,
last_step_first_order=True,
params=SolverParams(sde_variance=1.0, sigma_schedule=NoiseSchedule.KARRAS),
device=test_device,
)
manual_seed(2)
x = sd15.init_latents((512, 512))
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image_std_sde_karras_random_init)
@no_grad() @no_grad()
def test_diffusion_batch2(sd15_std: StableDiffusion_1): def test_diffusion_batch2(sd15_std: StableDiffusion_1):
sd15 = sd15_std sd15 = sd15_std

View file

@ -97,6 +97,29 @@ manual_seed(2)
image = pipe(prompt, negative_prompt=negative_prompt, guidance_scale=7.5).images[0] image = pipe(prompt, negative_prompt=negative_prompt, guidance_scale=7.5).images[0]
``` ```
- `expected_std_sde_karras_random_init.png` is generated with the following code (diffusers 0.30.2):
```python
import torch
from diffusers import StableDiffusionPipeline
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from refiners.fluxion.utils import manual_seed
model_id = "botp/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
pipe = pipe.to("cuda:1")
config = {**pipe.scheduler.config}
config["use_karras_sigmas"] = True
config["algorithm_type"] = "sde-dpmsolver++"
pipe.scheduler = DPMSolverMultistepScheduler.from_config(config)
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
manual_seed(2)
image = pipe(prompt, negative_prompt=negative_prompt, num_inference_steps=18, guidance_scale=7.5).images[0]
```
- `kitchen_mask.png` is made manually. - `kitchen_mask.png` is made manually.
- Controlnet guides have been manually generated (x) using open source software and models, namely: - Controlnet guides have been manually generated (x) using open source software and models, namely:

Binary file not shown.

After

Width:  |  Height:  |  Size: 378 KiB

View file

@ -1,3 +1,4 @@
import itertools
from typing import cast from typing import cast
from warnings import warn from warnings import warn
@ -29,8 +30,11 @@ def test_ddpm_diffusers():
assert equal(diffusers_scheduler.timesteps, solver.timesteps) assert equal(diffusers_scheduler.timesteps, solver.timesteps)
@pytest.mark.parametrize("n_steps, last_step_first_order", [(5, False), (5, True), (30, False), (30, True)]) @pytest.mark.parametrize(
def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool): "n_steps, last_step_first_order, sde_variance, use_karras_sigmas",
list(itertools.product([5, 30], [False, True], [0.0, 1.0], [False, True])),
)
def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool, sde_variance: float, use_karras_sigmas: bool):
from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler # type: ignore from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler # type: ignore
manual_seed(0) manual_seed(0)
@ -42,43 +46,17 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
lower_order_final=False, lower_order_final=False,
euler_at_final=last_step_first_order, euler_at_final=last_step_first_order,
final_sigmas_type="sigma_min", # default before Diffusers 0.26.0 final_sigmas_type="sigma_min", # default before Diffusers 0.26.0
algorithm_type="sde-dpmsolver++" if sde_variance == 1.0 else "dpmsolver++",
use_karras_sigmas=use_karras_sigmas,
) )
diffusers_scheduler.set_timesteps(n_steps) diffusers_scheduler.set_timesteps(n_steps)
solver = DPMSolver( solver = DPMSolver(
num_inference_steps=n_steps, num_inference_steps=n_steps,
last_step_first_order=last_step_first_order, last_step_first_order=last_step_first_order,
) params=SolverParams(
assert equal(solver.timesteps, diffusers_scheduler.timesteps) sde_variance=sde_variance,
sigma_schedule=NoiseSchedule.KARRAS if use_karras_sigmas else None,
sample = randn(1, 3, 32, 32) ),
predicted_noise = randn(1, 3, 32, 32)
for step, timestep in enumerate(diffusers_scheduler.timesteps):
diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
refiners_output = solver(x=sample, predicted_noise=predicted_noise, step=step)
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
@pytest.mark.parametrize("n_steps, last_step_first_order", [(5, False), (5, True), (30, False), (30, True)])
def test_dpm_solver_sde_diffusers(n_steps: int, last_step_first_order: bool):
from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler # type: ignore
manual_seed(0)
diffusers_scheduler = DiffuserScheduler(
beta_schedule="scaled_linear",
beta_start=0.00085,
beta_end=0.012,
lower_order_final=False,
euler_at_final=last_step_first_order,
final_sigmas_type="sigma_min", # default before Diffusers 0.26.0
algorithm_type="sde-dpmsolver++",
)
diffusers_scheduler.set_timesteps(n_steps)
solver = DPMSolver(
num_inference_steps=n_steps,
last_step_first_order=last_step_first_order,
params=SolverParams(sde_variance=1.0),
) )
assert equal(solver.timesteps, diffusers_scheduler.timesteps) assert equal(solver.timesteps, diffusers_scheduler.timesteps)
@ -94,8 +72,9 @@ def test_dpm_solver_sde_diffusers(n_steps: int, last_step_first_order: bool):
manual_seed(37) manual_seed(37)
refiners_outputs = [solver(x=sample, predicted_noise=predicted_noise, step=step) for step in range(n_steps)] refiners_outputs = [solver(x=sample, predicted_noise=predicted_noise, step=step) for step in range(n_steps)]
atol = 1e-4 if use_karras_sigmas else 1e-6
for step, (diffusers_output, refiners_output) in enumerate(zip(diffusers_outputs, refiners_outputs)): for step, (diffusers_output, refiners_output) in enumerate(zip(diffusers_outputs, refiners_outputs)):
assert allclose(diffusers_output, refiners_output, rtol=0.01, atol=1e-6), f"outputs differ at step {step}" assert allclose(diffusers_output, refiners_output, rtol=0.01, atol=atol), f"outputs differ at step {step}"
def test_ddim_diffusers(): def test_ddim_diffusers():