add karras sigmas to dpm solver

This commit is contained in:
limiteinductive 2024-09-06 10:56:24 +00:00 committed by Benjamin Trom
parent 5aef1408d8
commit af6c5aecbe
6 changed files with 215 additions and 88 deletions

View file

@ -1,18 +1,35 @@
import dataclasses
from collections import deque
from typing import NamedTuple
import numpy as np
import torch
from torch import Generator, Tensor, device as Device, dtype as Dtype
from refiners.foundationals.latent_diffusion.solvers.solver import (
BaseSolverParams,
ModelPredictionType,
NoiseSchedule,
Solver,
TimestepSpacing,
)
def safe_log(x: torch.Tensor, lower_bound: float = 1e-6) -> torch.Tensor:
"""Compute the log of a tensor with a lower bound."""
return torch.log(torch.maximum(x, torch.tensor(lower_bound)))
def safe_sqrt(x: torch.Tensor) -> torch.Tensor:
"""Compute the square root of a tensor ensuring that the input is non-negative"""
return torch.sqrt(torch.maximum(x, torch.tensor(0)))
class SolverTensors(NamedTuple):
cumulative_scale_factors: torch.Tensor
noise_std: torch.Tensor
signal_to_noise_ratios: torch.Tensor
class DPMSolver(Solver):
"""Diffusion probabilistic models (DPMs) solver.
@ -37,9 +54,9 @@ class DPMSolver(Solver):
first_inference_step: int = 0,
params: BaseSolverParams | None = None,
last_step_first_order: bool = False,
device: Device | str = "cpu",
dtype: Dtype = torch.float32,
):
device: torch.device | str = "cpu",
dtype: torch.dtype = torch.float32,
) -> None:
"""Initializes a new DPM solver.
Args:
@ -64,6 +81,14 @@ class DPMSolver(Solver):
)
self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2)
self.last_step_first_order = last_step_first_order
sigmas = self.noise_std / self.cumulative_scale_factors
self.sigmas = self._rescale_sigmas(sigmas, self.params.sigma_schedule)
sigma_min = sigmas[0:1] # corresponds to `final_sigmas_type="sigma_min" in diffusers`
self.sigmas = torch.cat([self.sigmas, sigma_min])
self.cumulative_scale_factors, self.noise_std, self.signal_to_noise_ratios = self._solver_tensors_from_sigmas(
self.sigmas
)
self.timesteps = self._timesteps_from_sigmas(sigmas)
def rebuild(
self: "DPMSolver",
@ -83,7 +108,7 @@ class DPMSolver(Solver):
r.last_step_first_order = self.last_step_first_order
return r
def _generate_timesteps(self) -> Tensor:
def _generate_timesteps(self) -> torch.Tensor:
if self.params.timesteps_spacing != TimestepSpacing.CUSTOM:
return super()._generate_timesteps()
@ -96,9 +121,75 @@ class DPMSolver(Solver):
np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
return torch.tensor(np_space).flip(0)
def _generate_sigmas(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Generate the sigmas used by the solver."""
assert self.params.sigma_schedule is not None, "sigma_schedule must be set for the DPM solver"
sigmas = self.noise_std / self.cumulative_scale_factors
sigmas = sigmas.flip(0)
rescaled_sigmas = self._rescale_sigmas(sigmas, self.params.sigma_schedule)
rescaled_sigmas = torch.cat([rescaled_sigmas, torch.tensor([0.0])])
return sigmas, rescaled_sigmas
def _rescale_sigmas(self, sigmas: torch.Tensor, sigma_schedule: NoiseSchedule | None) -> torch.Tensor:
"""Rescale the sigmas according to the sigma schedule."""
match sigma_schedule:
case NoiseSchedule.UNIFORM:
rho = 1
case NoiseSchedule.QUADRATIC:
rho = 2
case NoiseSchedule.KARRAS:
rho = 7
case None:
return torch.tensor(
np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()),
device=self.device,
)
linear_schedule = torch.linspace(0, 1, steps=self.num_inference_steps, device=self.device)
first_sigma = sigmas[0]
last_sigma = sigmas[-1]
rescaled_sigmas = (
first_sigma ** (1 / rho) + linear_schedule * (last_sigma ** (1 / rho) - first_sigma ** (1 / rho))
) ** rho
return rescaled_sigmas.flip(0)
def _timesteps_from_sigmas(self, sigmas: torch.Tensor) -> torch.Tensor:
"""Generate the timesteps from the sigmas."""
log_sigmas = safe_log(sigmas)
timesteps: list[torch.Tensor] = []
for sigma in self.sigmas[:-1]:
log_sigma = safe_log(sigma)
distance_matrix = log_sigma - log_sigmas.unsqueeze(1)
# Determine the range of sigma indices
low_indices = (distance_matrix >= 0).cumsum(dim=0).argmax(dim=0).clip(max=sigmas.size(0) - 2)
high_indices = low_indices + 1
low_log_sigma = log_sigmas[low_indices]
high_log_sigma = log_sigmas[high_indices]
# Interpolate sigma values
interpolation_weights = (low_log_sigma - log_sigma) / (low_log_sigma - high_log_sigma)
interpolation_weights = torch.clamp(interpolation_weights, 0, 1)
timestep = (1 - interpolation_weights) * low_indices + interpolation_weights * high_indices
timesteps.append(timestep)
return torch.cat(timesteps).round()
def _solver_tensors_from_sigmas(self, sigmas: torch.Tensor) -> SolverTensors:
"""Generate the tensors from the sigmas."""
cumulative_scale_factors = 1 / torch.sqrt(sigmas**2 + 1)
noise_std = sigmas * cumulative_scale_factors
signal_to_noise_ratios = safe_log(cumulative_scale_factors) - safe_log(noise_std)
return SolverTensors(
cumulative_scale_factors=cumulative_scale_factors,
noise_std=noise_std,
signal_to_noise_ratios=signal_to_noise_ratios,
)
def dpm_solver_first_order_update(
self, x: Tensor, noise: Tensor, step: int, sde_noise: Tensor | None = None
) -> Tensor:
self, x: torch.Tensor, noise: torch.Tensor, step: int, sde_noise: torch.Tensor | None = None
) -> torch.Tensor:
"""Applies a first-order backward Euler update to the input data `x`.
Args:
@ -109,32 +200,29 @@ class DPMSolver(Solver):
Returns:
The denoised version of the input data `x`.
"""
current_timestep = self.timesteps[step]
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])
current_ratio = self.signal_to_noise_ratios[step]
next_ratio = self.signal_to_noise_ratios[step + 1]
previous_ratio = self.signal_to_noise_ratios[previous_timestep]
current_ratio = self.signal_to_noise_ratios[current_timestep]
next_scale_factor = self.cumulative_scale_factors[step + 1]
previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
next_noise_std = self.noise_std[step + 1]
current_noise_std = self.noise_std[step]
previous_noise_std = self.noise_std[previous_timestep]
current_noise_std = self.noise_std[current_timestep]
ratio_delta = current_ratio - previous_ratio
ratio_delta = current_ratio - next_ratio
if sde_noise is None:
return (previous_noise_std / current_noise_std) * x + (
1.0 - torch.exp(ratio_delta)
) * previous_scale_factor * noise
return (next_noise_std / current_noise_std) * x + (1.0 - torch.exp(ratio_delta)) * next_scale_factor * noise
factor = 1.0 - torch.exp(2.0 * ratio_delta)
return (
(previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
+ previous_scale_factor * factor * noise
+ previous_noise_std * torch.sqrt(factor) * sde_noise
(next_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
+ next_scale_factor * factor * noise
+ next_noise_std * safe_sqrt(factor) * sde_noise
)
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int, sde_noise: Tensor | None = None) -> Tensor:
def multistep_dpm_solver_second_order_update(
self, x: torch.Tensor, step: int, sde_noise: torch.Tensor | None = None
) -> torch.Tensor:
"""Applies a second-order backward Euler update to the input data `x`.
Args:
@ -144,43 +232,41 @@ class DPMSolver(Solver):
Returns:
The denoised version of the input data `x`.
"""
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])
current_timestep = self.timesteps[step]
next_timestep = self.timesteps[step - 1]
current_data_estimation = self.estimated_data[-1]
next_data_estimation = self.estimated_data[-2]
previous_data_estimation = self.estimated_data[-2]
previous_ratio = self.signal_to_noise_ratios[previous_timestep]
current_ratio = self.signal_to_noise_ratios[current_timestep]
next_ratio = self.signal_to_noise_ratios[next_timestep]
next_ratio = self.signal_to_noise_ratios[step + 1]
current_ratio = self.signal_to_noise_ratios[step]
previous_ratio = self.signal_to_noise_ratios[step - 1]
previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
previous_noise_std = self.noise_std[previous_timestep]
current_noise_std = self.noise_std[current_timestep]
next_scale_factor = self.cumulative_scale_factors[step + 1]
next_noise_std = self.noise_std[step + 1]
current_noise_std = self.noise_std[step]
estimation_delta = (current_data_estimation - next_data_estimation) / (
(current_ratio - next_ratio) / (previous_ratio - current_ratio)
estimation_delta = (current_data_estimation - previous_data_estimation) / (
(current_ratio - previous_ratio) / (next_ratio - current_ratio)
)
ratio_delta = current_ratio - previous_ratio
ratio_delta = current_ratio - next_ratio
if sde_noise is None:
factor = 1.0 - torch.exp(ratio_delta)
return (
(previous_noise_std / current_noise_std) * x
+ previous_scale_factor * factor * current_data_estimation
+ 0.5 * previous_scale_factor * factor * estimation_delta
(next_noise_std / current_noise_std) * x
+ next_scale_factor * factor * current_data_estimation
+ 0.5 * next_scale_factor * factor * estimation_delta
)
factor = 1.0 - torch.exp(2.0 * ratio_delta)
return (
(previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
+ previous_scale_factor * factor * current_data_estimation
+ 0.5 * previous_scale_factor * factor * estimation_delta
+ previous_noise_std * torch.sqrt(factor) * sde_noise
(next_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
+ next_scale_factor * factor * current_data_estimation
+ 0.5 * next_scale_factor * factor * estimation_delta
+ next_noise_std * safe_sqrt(factor) * sde_noise
)
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
def __call__(
self, x: torch.Tensor, predicted_noise: torch.Tensor, step: int, generator: torch.Generator | None = None
) -> torch.Tensor:
"""Apply one step of the backward diffusion process.
Note:
@ -199,9 +285,8 @@ class DPMSolver(Solver):
"""
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
current_timestep = self.timesteps[step]
scale_factor = self.cumulative_scale_factors[current_timestep]
noise_ratio = self.noise_std[current_timestep]
scale_factor = self.cumulative_scale_factors[step]
noise_ratio = self.noise_std[step]
estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor
self.estimated_data.append(estimated_denoised_data)
variance = self.params.sde_variance

View file

@ -67,6 +67,7 @@ class BaseSolverParams:
initial_diffusion_rate: float | None
final_diffusion_rate: float | None
noise_schedule: NoiseSchedule | None
sigma_schedule: NoiseSchedule | None
model_prediction_type: ModelPredictionType | None
sde_variance: float
@ -91,6 +92,7 @@ class SolverParams(BaseSolverParams):
initial_diffusion_rate: float | None = None
final_diffusion_rate: float | None = None
noise_schedule: NoiseSchedule | None = None
sigma_schedule: NoiseSchedule | None = None
model_prediction_type: ModelPredictionType | None = None
sde_variance: float = 0.0
@ -103,6 +105,7 @@ class ResolvedSolverParams(BaseSolverParams):
initial_diffusion_rate: float
final_diffusion_rate: float
noise_schedule: NoiseSchedule
sigma_schedule: NoiseSchedule | None
model_prediction_type: ModelPredictionType
sde_variance: float
@ -140,6 +143,7 @@ class Solver(fl.Module, ABC):
initial_diffusion_rate=8.5e-4,
final_diffusion_rate=1.2e-2,
noise_schedule=NoiseSchedule.QUADRATIC,
sigma_schedule=None,
model_prediction_type=ModelPredictionType.NOISE,
sde_variance=0.0,
)
@ -404,14 +408,12 @@ class Solver(fl.Module, ABC):
A tensor representing the noise schedule.
"""
match self.params.noise_schedule:
case "uniform":
case NoiseSchedule.UNIFORM:
return 1 - self.sample_power_distribution(1)
case "quadratic":
case NoiseSchedule.QUADRATIC:
return 1 - self.sample_power_distribution(2)
case "karras":
case NoiseSchedule.KARRAS:
return 1 - self.sample_power_distribution(7)
case _:
raise ValueError(f"Unknown noise schedule: {self.params.noise_schedule}")
def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver":
"""Move the solver to the specified device and data type.

View file

@ -97,6 +97,11 @@ def expected_image_std_sde_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_sde_random_init.png").convert("RGB")
@pytest.fixture
def expected_image_std_sde_karras_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_sde_karras_random_init.png").convert("RGB")
@pytest.fixture
def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB")
@ -913,6 +918,39 @@ def test_diffusion_std_sde_random_init(
ensure_similar_images(predicted_image, expected_image_std_sde_random_init)
@no_grad()
def test_diffusion_std_sde_karras_random_init(
sd15_std_sde: StableDiffusion_1, expected_image_std_sde_karras_random_init: Image.Image, test_device: torch.device
):
sd15 = sd15_std_sde
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.solver = DPMSolver(
num_inference_steps=18,
last_step_first_order=True,
params=SolverParams(sde_variance=1.0, sigma_schedule=NoiseSchedule.KARRAS),
device=test_device,
)
manual_seed(2)
x = sd15.init_latents((512, 512))
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image_std_sde_karras_random_init)
@no_grad()
def test_diffusion_batch2(sd15_std: StableDiffusion_1):
sd15 = sd15_std

View file

@ -97,6 +97,29 @@ manual_seed(2)
image = pipe(prompt, negative_prompt=negative_prompt, guidance_scale=7.5).images[0]
```
- `expected_std_sde_karras_random_init.png` is generated with the following code (diffusers 0.30.2):
```python
import torch
from diffusers import StableDiffusionPipeline
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from refiners.fluxion.utils import manual_seed
model_id = "botp/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
pipe = pipe.to("cuda:1")
config = {**pipe.scheduler.config}
config["use_karras_sigmas"] = True
config["algorithm_type"] = "sde-dpmsolver++"
pipe.scheduler = DPMSolverMultistepScheduler.from_config(config)
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
manual_seed(2)
image = pipe(prompt, negative_prompt=negative_prompt, num_inference_steps=18, guidance_scale=7.5).images[0]
```
- `kitchen_mask.png` is made manually.
- Controlnet guides have been manually generated (x) using open source software and models, namely:

Binary file not shown.

After

Width:  |  Height:  |  Size: 378 KiB

View file

@ -1,3 +1,4 @@
import itertools
from typing import cast
from warnings import warn
@ -29,8 +30,11 @@ def test_ddpm_diffusers():
assert equal(diffusers_scheduler.timesteps, solver.timesteps)
@pytest.mark.parametrize("n_steps, last_step_first_order", [(5, False), (5, True), (30, False), (30, True)])
def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
@pytest.mark.parametrize(
"n_steps, last_step_first_order, sde_variance, use_karras_sigmas",
list(itertools.product([5, 30], [False, True], [0.0, 1.0], [False, True])),
)
def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool, sde_variance: float, use_karras_sigmas: bool):
from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler # type: ignore
manual_seed(0)
@ -42,43 +46,17 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
lower_order_final=False,
euler_at_final=last_step_first_order,
final_sigmas_type="sigma_min", # default before Diffusers 0.26.0
algorithm_type="sde-dpmsolver++" if sde_variance == 1.0 else "dpmsolver++",
use_karras_sigmas=use_karras_sigmas,
)
diffusers_scheduler.set_timesteps(n_steps)
solver = DPMSolver(
num_inference_steps=n_steps,
last_step_first_order=last_step_first_order,
)
assert equal(solver.timesteps, diffusers_scheduler.timesteps)
sample = randn(1, 3, 32, 32)
predicted_noise = randn(1, 3, 32, 32)
for step, timestep in enumerate(diffusers_scheduler.timesteps):
diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
refiners_output = solver(x=sample, predicted_noise=predicted_noise, step=step)
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
@pytest.mark.parametrize("n_steps, last_step_first_order", [(5, False), (5, True), (30, False), (30, True)])
def test_dpm_solver_sde_diffusers(n_steps: int, last_step_first_order: bool):
from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler # type: ignore
manual_seed(0)
diffusers_scheduler = DiffuserScheduler(
beta_schedule="scaled_linear",
beta_start=0.00085,
beta_end=0.012,
lower_order_final=False,
euler_at_final=last_step_first_order,
final_sigmas_type="sigma_min", # default before Diffusers 0.26.0
algorithm_type="sde-dpmsolver++",
)
diffusers_scheduler.set_timesteps(n_steps)
solver = DPMSolver(
num_inference_steps=n_steps,
last_step_first_order=last_step_first_order,
params=SolverParams(sde_variance=1.0),
params=SolverParams(
sde_variance=sde_variance,
sigma_schedule=NoiseSchedule.KARRAS if use_karras_sigmas else None,
),
)
assert equal(solver.timesteps, diffusers_scheduler.timesteps)
@ -94,8 +72,9 @@ def test_dpm_solver_sde_diffusers(n_steps: int, last_step_first_order: bool):
manual_seed(37)
refiners_outputs = [solver(x=sample, predicted_noise=predicted_noise, step=step) for step in range(n_steps)]
atol = 1e-4 if use_karras_sigmas else 1e-6
for step, (diffusers_output, refiners_output) in enumerate(zip(diffusers_outputs, refiners_outputs)):
assert allclose(diffusers_output, refiners_output, rtol=0.01, atol=1e-6), f"outputs differ at step {step}"
assert allclose(diffusers_output, refiners_output, rtol=0.01, atol=atol), f"outputs differ at step {step}"
def test_ddim_diffusers():