mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 07:08:45 +00:00
add karras sigmas to dpm solver
This commit is contained in:
parent
5aef1408d8
commit
af6c5aecbe
|
@ -1,18 +1,35 @@
|
|||
import dataclasses
|
||||
from collections import deque
|
||||
from typing import NamedTuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Generator, Tensor, device as Device, dtype as Dtype
|
||||
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||
BaseSolverParams,
|
||||
ModelPredictionType,
|
||||
NoiseSchedule,
|
||||
Solver,
|
||||
TimestepSpacing,
|
||||
)
|
||||
|
||||
|
||||
def safe_log(x: torch.Tensor, lower_bound: float = 1e-6) -> torch.Tensor:
|
||||
"""Compute the log of a tensor with a lower bound."""
|
||||
return torch.log(torch.maximum(x, torch.tensor(lower_bound)))
|
||||
|
||||
|
||||
def safe_sqrt(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute the square root of a tensor ensuring that the input is non-negative"""
|
||||
return torch.sqrt(torch.maximum(x, torch.tensor(0)))
|
||||
|
||||
|
||||
class SolverTensors(NamedTuple):
|
||||
cumulative_scale_factors: torch.Tensor
|
||||
noise_std: torch.Tensor
|
||||
signal_to_noise_ratios: torch.Tensor
|
||||
|
||||
|
||||
class DPMSolver(Solver):
|
||||
"""Diffusion probabilistic models (DPMs) solver.
|
||||
|
||||
|
@ -37,9 +54,9 @@ class DPMSolver(Solver):
|
|||
first_inference_step: int = 0,
|
||||
params: BaseSolverParams | None = None,
|
||||
last_step_first_order: bool = False,
|
||||
device: Device | str = "cpu",
|
||||
dtype: Dtype = torch.float32,
|
||||
):
|
||||
device: torch.device | str = "cpu",
|
||||
dtype: torch.dtype = torch.float32,
|
||||
) -> None:
|
||||
"""Initializes a new DPM solver.
|
||||
|
||||
Args:
|
||||
|
@ -64,6 +81,14 @@ class DPMSolver(Solver):
|
|||
)
|
||||
self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2)
|
||||
self.last_step_first_order = last_step_first_order
|
||||
sigmas = self.noise_std / self.cumulative_scale_factors
|
||||
self.sigmas = self._rescale_sigmas(sigmas, self.params.sigma_schedule)
|
||||
sigma_min = sigmas[0:1] # corresponds to `final_sigmas_type="sigma_min" in diffusers`
|
||||
self.sigmas = torch.cat([self.sigmas, sigma_min])
|
||||
self.cumulative_scale_factors, self.noise_std, self.signal_to_noise_ratios = self._solver_tensors_from_sigmas(
|
||||
self.sigmas
|
||||
)
|
||||
self.timesteps = self._timesteps_from_sigmas(sigmas)
|
||||
|
||||
def rebuild(
|
||||
self: "DPMSolver",
|
||||
|
@ -83,7 +108,7 @@ class DPMSolver(Solver):
|
|||
r.last_step_first_order = self.last_step_first_order
|
||||
return r
|
||||
|
||||
def _generate_timesteps(self) -> Tensor:
|
||||
def _generate_timesteps(self) -> torch.Tensor:
|
||||
if self.params.timesteps_spacing != TimestepSpacing.CUSTOM:
|
||||
return super()._generate_timesteps()
|
||||
|
||||
|
@ -96,9 +121,75 @@ class DPMSolver(Solver):
|
|||
np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
|
||||
return torch.tensor(np_space).flip(0)
|
||||
|
||||
def _generate_sigmas(self) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Generate the sigmas used by the solver."""
|
||||
assert self.params.sigma_schedule is not None, "sigma_schedule must be set for the DPM solver"
|
||||
sigmas = self.noise_std / self.cumulative_scale_factors
|
||||
sigmas = sigmas.flip(0)
|
||||
rescaled_sigmas = self._rescale_sigmas(sigmas, self.params.sigma_schedule)
|
||||
rescaled_sigmas = torch.cat([rescaled_sigmas, torch.tensor([0.0])])
|
||||
return sigmas, rescaled_sigmas
|
||||
|
||||
def _rescale_sigmas(self, sigmas: torch.Tensor, sigma_schedule: NoiseSchedule | None) -> torch.Tensor:
|
||||
"""Rescale the sigmas according to the sigma schedule."""
|
||||
match sigma_schedule:
|
||||
case NoiseSchedule.UNIFORM:
|
||||
rho = 1
|
||||
case NoiseSchedule.QUADRATIC:
|
||||
rho = 2
|
||||
case NoiseSchedule.KARRAS:
|
||||
rho = 7
|
||||
case None:
|
||||
return torch.tensor(
|
||||
np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()),
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
linear_schedule = torch.linspace(0, 1, steps=self.num_inference_steps, device=self.device)
|
||||
first_sigma = sigmas[0]
|
||||
last_sigma = sigmas[-1]
|
||||
rescaled_sigmas = (
|
||||
first_sigma ** (1 / rho) + linear_schedule * (last_sigma ** (1 / rho) - first_sigma ** (1 / rho))
|
||||
) ** rho
|
||||
return rescaled_sigmas.flip(0)
|
||||
|
||||
def _timesteps_from_sigmas(self, sigmas: torch.Tensor) -> torch.Tensor:
|
||||
"""Generate the timesteps from the sigmas."""
|
||||
log_sigmas = safe_log(sigmas)
|
||||
timesteps: list[torch.Tensor] = []
|
||||
for sigma in self.sigmas[:-1]:
|
||||
log_sigma = safe_log(sigma)
|
||||
distance_matrix = log_sigma - log_sigmas.unsqueeze(1)
|
||||
|
||||
# Determine the range of sigma indices
|
||||
low_indices = (distance_matrix >= 0).cumsum(dim=0).argmax(dim=0).clip(max=sigmas.size(0) - 2)
|
||||
high_indices = low_indices + 1
|
||||
|
||||
low_log_sigma = log_sigmas[low_indices]
|
||||
high_log_sigma = log_sigmas[high_indices]
|
||||
|
||||
# Interpolate sigma values
|
||||
interpolation_weights = (low_log_sigma - log_sigma) / (low_log_sigma - high_log_sigma)
|
||||
interpolation_weights = torch.clamp(interpolation_weights, 0, 1)
|
||||
timestep = (1 - interpolation_weights) * low_indices + interpolation_weights * high_indices
|
||||
timesteps.append(timestep)
|
||||
|
||||
return torch.cat(timesteps).round()
|
||||
|
||||
def _solver_tensors_from_sigmas(self, sigmas: torch.Tensor) -> SolverTensors:
|
||||
"""Generate the tensors from the sigmas."""
|
||||
cumulative_scale_factors = 1 / torch.sqrt(sigmas**2 + 1)
|
||||
noise_std = sigmas * cumulative_scale_factors
|
||||
signal_to_noise_ratios = safe_log(cumulative_scale_factors) - safe_log(noise_std)
|
||||
return SolverTensors(
|
||||
cumulative_scale_factors=cumulative_scale_factors,
|
||||
noise_std=noise_std,
|
||||
signal_to_noise_ratios=signal_to_noise_ratios,
|
||||
)
|
||||
|
||||
def dpm_solver_first_order_update(
|
||||
self, x: Tensor, noise: Tensor, step: int, sde_noise: Tensor | None = None
|
||||
) -> Tensor:
|
||||
self, x: torch.Tensor, noise: torch.Tensor, step: int, sde_noise: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
"""Applies a first-order backward Euler update to the input data `x`.
|
||||
|
||||
Args:
|
||||
|
@ -109,32 +200,29 @@ class DPMSolver(Solver):
|
|||
Returns:
|
||||
The denoised version of the input data `x`.
|
||||
"""
|
||||
current_timestep = self.timesteps[step]
|
||||
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])
|
||||
current_ratio = self.signal_to_noise_ratios[step]
|
||||
next_ratio = self.signal_to_noise_ratios[step + 1]
|
||||
|
||||
previous_ratio = self.signal_to_noise_ratios[previous_timestep]
|
||||
current_ratio = self.signal_to_noise_ratios[current_timestep]
|
||||
next_scale_factor = self.cumulative_scale_factors[step + 1]
|
||||
|
||||
previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
|
||||
next_noise_std = self.noise_std[step + 1]
|
||||
current_noise_std = self.noise_std[step]
|
||||
|
||||
previous_noise_std = self.noise_std[previous_timestep]
|
||||
current_noise_std = self.noise_std[current_timestep]
|
||||
|
||||
ratio_delta = current_ratio - previous_ratio
|
||||
ratio_delta = current_ratio - next_ratio
|
||||
|
||||
if sde_noise is None:
|
||||
return (previous_noise_std / current_noise_std) * x + (
|
||||
1.0 - torch.exp(ratio_delta)
|
||||
) * previous_scale_factor * noise
|
||||
return (next_noise_std / current_noise_std) * x + (1.0 - torch.exp(ratio_delta)) * next_scale_factor * noise
|
||||
|
||||
factor = 1.0 - torch.exp(2.0 * ratio_delta)
|
||||
return (
|
||||
(previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
|
||||
+ previous_scale_factor * factor * noise
|
||||
+ previous_noise_std * torch.sqrt(factor) * sde_noise
|
||||
(next_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
|
||||
+ next_scale_factor * factor * noise
|
||||
+ next_noise_std * safe_sqrt(factor) * sde_noise
|
||||
)
|
||||
|
||||
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int, sde_noise: Tensor | None = None) -> Tensor:
|
||||
def multistep_dpm_solver_second_order_update(
|
||||
self, x: torch.Tensor, step: int, sde_noise: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
"""Applies a second-order backward Euler update to the input data `x`.
|
||||
|
||||
Args:
|
||||
|
@ -144,43 +232,41 @@ class DPMSolver(Solver):
|
|||
Returns:
|
||||
The denoised version of the input data `x`.
|
||||
"""
|
||||
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])
|
||||
current_timestep = self.timesteps[step]
|
||||
next_timestep = self.timesteps[step - 1]
|
||||
|
||||
current_data_estimation = self.estimated_data[-1]
|
||||
next_data_estimation = self.estimated_data[-2]
|
||||
previous_data_estimation = self.estimated_data[-2]
|
||||
|
||||
previous_ratio = self.signal_to_noise_ratios[previous_timestep]
|
||||
current_ratio = self.signal_to_noise_ratios[current_timestep]
|
||||
next_ratio = self.signal_to_noise_ratios[next_timestep]
|
||||
next_ratio = self.signal_to_noise_ratios[step + 1]
|
||||
current_ratio = self.signal_to_noise_ratios[step]
|
||||
previous_ratio = self.signal_to_noise_ratios[step - 1]
|
||||
|
||||
previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
|
||||
previous_noise_std = self.noise_std[previous_timestep]
|
||||
current_noise_std = self.noise_std[current_timestep]
|
||||
next_scale_factor = self.cumulative_scale_factors[step + 1]
|
||||
next_noise_std = self.noise_std[step + 1]
|
||||
current_noise_std = self.noise_std[step]
|
||||
|
||||
estimation_delta = (current_data_estimation - next_data_estimation) / (
|
||||
(current_ratio - next_ratio) / (previous_ratio - current_ratio)
|
||||
estimation_delta = (current_data_estimation - previous_data_estimation) / (
|
||||
(current_ratio - previous_ratio) / (next_ratio - current_ratio)
|
||||
)
|
||||
ratio_delta = current_ratio - previous_ratio
|
||||
ratio_delta = current_ratio - next_ratio
|
||||
|
||||
if sde_noise is None:
|
||||
factor = 1.0 - torch.exp(ratio_delta)
|
||||
return (
|
||||
(previous_noise_std / current_noise_std) * x
|
||||
+ previous_scale_factor * factor * current_data_estimation
|
||||
+ 0.5 * previous_scale_factor * factor * estimation_delta
|
||||
(next_noise_std / current_noise_std) * x
|
||||
+ next_scale_factor * factor * current_data_estimation
|
||||
+ 0.5 * next_scale_factor * factor * estimation_delta
|
||||
)
|
||||
|
||||
factor = 1.0 - torch.exp(2.0 * ratio_delta)
|
||||
return (
|
||||
(previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
|
||||
+ previous_scale_factor * factor * current_data_estimation
|
||||
+ 0.5 * previous_scale_factor * factor * estimation_delta
|
||||
+ previous_noise_std * torch.sqrt(factor) * sde_noise
|
||||
(next_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
|
||||
+ next_scale_factor * factor * current_data_estimation
|
||||
+ 0.5 * next_scale_factor * factor * estimation_delta
|
||||
+ next_noise_std * safe_sqrt(factor) * sde_noise
|
||||
)
|
||||
|
||||
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||
def __call__(
|
||||
self, x: torch.Tensor, predicted_noise: torch.Tensor, step: int, generator: torch.Generator | None = None
|
||||
) -> torch.Tensor:
|
||||
"""Apply one step of the backward diffusion process.
|
||||
|
||||
Note:
|
||||
|
@ -199,9 +285,8 @@ class DPMSolver(Solver):
|
|||
"""
|
||||
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
||||
|
||||
current_timestep = self.timesteps[step]
|
||||
scale_factor = self.cumulative_scale_factors[current_timestep]
|
||||
noise_ratio = self.noise_std[current_timestep]
|
||||
scale_factor = self.cumulative_scale_factors[step]
|
||||
noise_ratio = self.noise_std[step]
|
||||
estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor
|
||||
self.estimated_data.append(estimated_denoised_data)
|
||||
variance = self.params.sde_variance
|
||||
|
|
|
@ -67,6 +67,7 @@ class BaseSolverParams:
|
|||
initial_diffusion_rate: float | None
|
||||
final_diffusion_rate: float | None
|
||||
noise_schedule: NoiseSchedule | None
|
||||
sigma_schedule: NoiseSchedule | None
|
||||
model_prediction_type: ModelPredictionType | None
|
||||
sde_variance: float
|
||||
|
||||
|
@ -91,6 +92,7 @@ class SolverParams(BaseSolverParams):
|
|||
initial_diffusion_rate: float | None = None
|
||||
final_diffusion_rate: float | None = None
|
||||
noise_schedule: NoiseSchedule | None = None
|
||||
sigma_schedule: NoiseSchedule | None = None
|
||||
model_prediction_type: ModelPredictionType | None = None
|
||||
sde_variance: float = 0.0
|
||||
|
||||
|
@ -103,6 +105,7 @@ class ResolvedSolverParams(BaseSolverParams):
|
|||
initial_diffusion_rate: float
|
||||
final_diffusion_rate: float
|
||||
noise_schedule: NoiseSchedule
|
||||
sigma_schedule: NoiseSchedule | None
|
||||
model_prediction_type: ModelPredictionType
|
||||
sde_variance: float
|
||||
|
||||
|
@ -140,6 +143,7 @@ class Solver(fl.Module, ABC):
|
|||
initial_diffusion_rate=8.5e-4,
|
||||
final_diffusion_rate=1.2e-2,
|
||||
noise_schedule=NoiseSchedule.QUADRATIC,
|
||||
sigma_schedule=None,
|
||||
model_prediction_type=ModelPredictionType.NOISE,
|
||||
sde_variance=0.0,
|
||||
)
|
||||
|
@ -404,14 +408,12 @@ class Solver(fl.Module, ABC):
|
|||
A tensor representing the noise schedule.
|
||||
"""
|
||||
match self.params.noise_schedule:
|
||||
case "uniform":
|
||||
case NoiseSchedule.UNIFORM:
|
||||
return 1 - self.sample_power_distribution(1)
|
||||
case "quadratic":
|
||||
case NoiseSchedule.QUADRATIC:
|
||||
return 1 - self.sample_power_distribution(2)
|
||||
case "karras":
|
||||
case NoiseSchedule.KARRAS:
|
||||
return 1 - self.sample_power_distribution(7)
|
||||
case _:
|
||||
raise ValueError(f"Unknown noise schedule: {self.params.noise_schedule}")
|
||||
|
||||
def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver":
|
||||
"""Move the solver to the specified device and data type.
|
||||
|
|
|
@ -97,6 +97,11 @@ def expected_image_std_sde_random_init(ref_path: Path) -> Image.Image:
|
|||
return _img_open(ref_path / "expected_std_sde_random_init.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_std_sde_karras_random_init(ref_path: Path) -> Image.Image:
|
||||
return _img_open(ref_path / "expected_std_sde_karras_random_init.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
|
||||
return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB")
|
||||
|
@ -913,6 +918,39 @@ def test_diffusion_std_sde_random_init(
|
|||
ensure_similar_images(predicted_image, expected_image_std_sde_random_init)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_diffusion_std_sde_karras_random_init(
|
||||
sd15_std_sde: StableDiffusion_1, expected_image_std_sde_karras_random_init: Image.Image, test_device: torch.device
|
||||
):
|
||||
sd15 = sd15_std_sde
|
||||
|
||||
prompt = "a cute cat, detailed high-quality professional image"
|
||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
|
||||
sd15.solver = DPMSolver(
|
||||
num_inference_steps=18,
|
||||
last_step_first_order=True,
|
||||
params=SolverParams(sde_variance=1.0, sigma_schedule=NoiseSchedule.KARRAS),
|
||||
device=test_device,
|
||||
)
|
||||
|
||||
manual_seed(2)
|
||||
x = sd15.init_latents((512, 512))
|
||||
|
||||
for step in sd15.steps:
|
||||
x = sd15(
|
||||
x,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_std_sde_karras_random_init)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_diffusion_batch2(sd15_std: StableDiffusion_1):
|
||||
sd15 = sd15_std
|
||||
|
|
|
@ -97,6 +97,29 @@ manual_seed(2)
|
|||
image = pipe(prompt, negative_prompt=negative_prompt, guidance_scale=7.5).images[0]
|
||||
```
|
||||
|
||||
- `expected_std_sde_karras_random_init.png` is generated with the following code (diffusers 0.30.2):
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
|
||||
from refiners.fluxion.utils import manual_seed
|
||||
|
||||
model_id = "botp/stable-diffusion-v1-5"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
|
||||
pipe = pipe.to("cuda:1")
|
||||
|
||||
config = {**pipe.scheduler.config}
|
||||
config["use_karras_sigmas"] = True
|
||||
config["algorithm_type"] = "sde-dpmsolver++"
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(config)
|
||||
|
||||
prompt = "a cute cat, detailed high-quality professional image"
|
||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
manual_seed(2)
|
||||
image = pipe(prompt, negative_prompt=negative_prompt, num_inference_steps=18, guidance_scale=7.5).images[0]
|
||||
```
|
||||
|
||||
- `kitchen_mask.png` is made manually.
|
||||
|
||||
- Controlnet guides have been manually generated (x) using open source software and models, namely:
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 378 KiB |
|
@ -1,3 +1,4 @@
|
|||
import itertools
|
||||
from typing import cast
|
||||
from warnings import warn
|
||||
|
||||
|
@ -29,8 +30,11 @@ def test_ddpm_diffusers():
|
|||
assert equal(diffusers_scheduler.timesteps, solver.timesteps)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_steps, last_step_first_order", [(5, False), (5, True), (30, False), (30, True)])
|
||||
def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
|
||||
@pytest.mark.parametrize(
|
||||
"n_steps, last_step_first_order, sde_variance, use_karras_sigmas",
|
||||
list(itertools.product([5, 30], [False, True], [0.0, 1.0], [False, True])),
|
||||
)
|
||||
def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool, sde_variance: float, use_karras_sigmas: bool):
|
||||
from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler # type: ignore
|
||||
|
||||
manual_seed(0)
|
||||
|
@ -42,43 +46,17 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
|
|||
lower_order_final=False,
|
||||
euler_at_final=last_step_first_order,
|
||||
final_sigmas_type="sigma_min", # default before Diffusers 0.26.0
|
||||
algorithm_type="sde-dpmsolver++" if sde_variance == 1.0 else "dpmsolver++",
|
||||
use_karras_sigmas=use_karras_sigmas,
|
||||
)
|
||||
diffusers_scheduler.set_timesteps(n_steps)
|
||||
solver = DPMSolver(
|
||||
num_inference_steps=n_steps,
|
||||
last_step_first_order=last_step_first_order,
|
||||
)
|
||||
assert equal(solver.timesteps, diffusers_scheduler.timesteps)
|
||||
|
||||
sample = randn(1, 3, 32, 32)
|
||||
predicted_noise = randn(1, 3, 32, 32)
|
||||
|
||||
for step, timestep in enumerate(diffusers_scheduler.timesteps):
|
||||
diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
|
||||
refiners_output = solver(x=sample, predicted_noise=predicted_noise, step=step)
|
||||
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_steps, last_step_first_order", [(5, False), (5, True), (30, False), (30, True)])
|
||||
def test_dpm_solver_sde_diffusers(n_steps: int, last_step_first_order: bool):
|
||||
from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler # type: ignore
|
||||
|
||||
manual_seed(0)
|
||||
|
||||
diffusers_scheduler = DiffuserScheduler(
|
||||
beta_schedule="scaled_linear",
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
lower_order_final=False,
|
||||
euler_at_final=last_step_first_order,
|
||||
final_sigmas_type="sigma_min", # default before Diffusers 0.26.0
|
||||
algorithm_type="sde-dpmsolver++",
|
||||
)
|
||||
diffusers_scheduler.set_timesteps(n_steps)
|
||||
solver = DPMSolver(
|
||||
num_inference_steps=n_steps,
|
||||
last_step_first_order=last_step_first_order,
|
||||
params=SolverParams(sde_variance=1.0),
|
||||
params=SolverParams(
|
||||
sde_variance=sde_variance,
|
||||
sigma_schedule=NoiseSchedule.KARRAS if use_karras_sigmas else None,
|
||||
),
|
||||
)
|
||||
assert equal(solver.timesteps, diffusers_scheduler.timesteps)
|
||||
|
||||
|
@ -94,8 +72,9 @@ def test_dpm_solver_sde_diffusers(n_steps: int, last_step_first_order: bool):
|
|||
manual_seed(37)
|
||||
refiners_outputs = [solver(x=sample, predicted_noise=predicted_noise, step=step) for step in range(n_steps)]
|
||||
|
||||
atol = 1e-4 if use_karras_sigmas else 1e-6
|
||||
for step, (diffusers_output, refiners_output) in enumerate(zip(diffusers_outputs, refiners_outputs)):
|
||||
assert allclose(diffusers_output, refiners_output, rtol=0.01, atol=1e-6), f"outputs differ at step {step}"
|
||||
assert allclose(diffusers_output, refiners_output, rtol=0.01, atol=atol), f"outputs differ at step {step}"
|
||||
|
||||
|
||||
def test_ddim_diffusers():
|
||||
|
|
Loading…
Reference in a new issue