mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
add karras sigmas to dpm solver
This commit is contained in:
parent
5aef1408d8
commit
af6c5aecbe
|
@ -1,18 +1,35 @@
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import Generator, Tensor, device as Device, dtype as Dtype
|
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||||
BaseSolverParams,
|
BaseSolverParams,
|
||||||
ModelPredictionType,
|
ModelPredictionType,
|
||||||
|
NoiseSchedule,
|
||||||
Solver,
|
Solver,
|
||||||
TimestepSpacing,
|
TimestepSpacing,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def safe_log(x: torch.Tensor, lower_bound: float = 1e-6) -> torch.Tensor:
|
||||||
|
"""Compute the log of a tensor with a lower bound."""
|
||||||
|
return torch.log(torch.maximum(x, torch.tensor(lower_bound)))
|
||||||
|
|
||||||
|
|
||||||
|
def safe_sqrt(x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Compute the square root of a tensor ensuring that the input is non-negative"""
|
||||||
|
return torch.sqrt(torch.maximum(x, torch.tensor(0)))
|
||||||
|
|
||||||
|
|
||||||
|
class SolverTensors(NamedTuple):
|
||||||
|
cumulative_scale_factors: torch.Tensor
|
||||||
|
noise_std: torch.Tensor
|
||||||
|
signal_to_noise_ratios: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
class DPMSolver(Solver):
|
class DPMSolver(Solver):
|
||||||
"""Diffusion probabilistic models (DPMs) solver.
|
"""Diffusion probabilistic models (DPMs) solver.
|
||||||
|
|
||||||
|
@ -37,9 +54,9 @@ class DPMSolver(Solver):
|
||||||
first_inference_step: int = 0,
|
first_inference_step: int = 0,
|
||||||
params: BaseSolverParams | None = None,
|
params: BaseSolverParams | None = None,
|
||||||
last_step_first_order: bool = False,
|
last_step_first_order: bool = False,
|
||||||
device: Device | str = "cpu",
|
device: torch.device | str = "cpu",
|
||||||
dtype: Dtype = torch.float32,
|
dtype: torch.dtype = torch.float32,
|
||||||
):
|
) -> None:
|
||||||
"""Initializes a new DPM solver.
|
"""Initializes a new DPM solver.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -64,6 +81,14 @@ class DPMSolver(Solver):
|
||||||
)
|
)
|
||||||
self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2)
|
self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2)
|
||||||
self.last_step_first_order = last_step_first_order
|
self.last_step_first_order = last_step_first_order
|
||||||
|
sigmas = self.noise_std / self.cumulative_scale_factors
|
||||||
|
self.sigmas = self._rescale_sigmas(sigmas, self.params.sigma_schedule)
|
||||||
|
sigma_min = sigmas[0:1] # corresponds to `final_sigmas_type="sigma_min" in diffusers`
|
||||||
|
self.sigmas = torch.cat([self.sigmas, sigma_min])
|
||||||
|
self.cumulative_scale_factors, self.noise_std, self.signal_to_noise_ratios = self._solver_tensors_from_sigmas(
|
||||||
|
self.sigmas
|
||||||
|
)
|
||||||
|
self.timesteps = self._timesteps_from_sigmas(sigmas)
|
||||||
|
|
||||||
def rebuild(
|
def rebuild(
|
||||||
self: "DPMSolver",
|
self: "DPMSolver",
|
||||||
|
@ -83,7 +108,7 @@ class DPMSolver(Solver):
|
||||||
r.last_step_first_order = self.last_step_first_order
|
r.last_step_first_order = self.last_step_first_order
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def _generate_timesteps(self) -> Tensor:
|
def _generate_timesteps(self) -> torch.Tensor:
|
||||||
if self.params.timesteps_spacing != TimestepSpacing.CUSTOM:
|
if self.params.timesteps_spacing != TimestepSpacing.CUSTOM:
|
||||||
return super()._generate_timesteps()
|
return super()._generate_timesteps()
|
||||||
|
|
||||||
|
@ -96,9 +121,75 @@ class DPMSolver(Solver):
|
||||||
np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
|
np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
|
||||||
return torch.tensor(np_space).flip(0)
|
return torch.tensor(np_space).flip(0)
|
||||||
|
|
||||||
|
def _generate_sigmas(self) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Generate the sigmas used by the solver."""
|
||||||
|
assert self.params.sigma_schedule is not None, "sigma_schedule must be set for the DPM solver"
|
||||||
|
sigmas = self.noise_std / self.cumulative_scale_factors
|
||||||
|
sigmas = sigmas.flip(0)
|
||||||
|
rescaled_sigmas = self._rescale_sigmas(sigmas, self.params.sigma_schedule)
|
||||||
|
rescaled_sigmas = torch.cat([rescaled_sigmas, torch.tensor([0.0])])
|
||||||
|
return sigmas, rescaled_sigmas
|
||||||
|
|
||||||
|
def _rescale_sigmas(self, sigmas: torch.Tensor, sigma_schedule: NoiseSchedule | None) -> torch.Tensor:
|
||||||
|
"""Rescale the sigmas according to the sigma schedule."""
|
||||||
|
match sigma_schedule:
|
||||||
|
case NoiseSchedule.UNIFORM:
|
||||||
|
rho = 1
|
||||||
|
case NoiseSchedule.QUADRATIC:
|
||||||
|
rho = 2
|
||||||
|
case NoiseSchedule.KARRAS:
|
||||||
|
rho = 7
|
||||||
|
case None:
|
||||||
|
return torch.tensor(
|
||||||
|
np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()),
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
linear_schedule = torch.linspace(0, 1, steps=self.num_inference_steps, device=self.device)
|
||||||
|
first_sigma = sigmas[0]
|
||||||
|
last_sigma = sigmas[-1]
|
||||||
|
rescaled_sigmas = (
|
||||||
|
first_sigma ** (1 / rho) + linear_schedule * (last_sigma ** (1 / rho) - first_sigma ** (1 / rho))
|
||||||
|
) ** rho
|
||||||
|
return rescaled_sigmas.flip(0)
|
||||||
|
|
||||||
|
def _timesteps_from_sigmas(self, sigmas: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Generate the timesteps from the sigmas."""
|
||||||
|
log_sigmas = safe_log(sigmas)
|
||||||
|
timesteps: list[torch.Tensor] = []
|
||||||
|
for sigma in self.sigmas[:-1]:
|
||||||
|
log_sigma = safe_log(sigma)
|
||||||
|
distance_matrix = log_sigma - log_sigmas.unsqueeze(1)
|
||||||
|
|
||||||
|
# Determine the range of sigma indices
|
||||||
|
low_indices = (distance_matrix >= 0).cumsum(dim=0).argmax(dim=0).clip(max=sigmas.size(0) - 2)
|
||||||
|
high_indices = low_indices + 1
|
||||||
|
|
||||||
|
low_log_sigma = log_sigmas[low_indices]
|
||||||
|
high_log_sigma = log_sigmas[high_indices]
|
||||||
|
|
||||||
|
# Interpolate sigma values
|
||||||
|
interpolation_weights = (low_log_sigma - log_sigma) / (low_log_sigma - high_log_sigma)
|
||||||
|
interpolation_weights = torch.clamp(interpolation_weights, 0, 1)
|
||||||
|
timestep = (1 - interpolation_weights) * low_indices + interpolation_weights * high_indices
|
||||||
|
timesteps.append(timestep)
|
||||||
|
|
||||||
|
return torch.cat(timesteps).round()
|
||||||
|
|
||||||
|
def _solver_tensors_from_sigmas(self, sigmas: torch.Tensor) -> SolverTensors:
|
||||||
|
"""Generate the tensors from the sigmas."""
|
||||||
|
cumulative_scale_factors = 1 / torch.sqrt(sigmas**2 + 1)
|
||||||
|
noise_std = sigmas * cumulative_scale_factors
|
||||||
|
signal_to_noise_ratios = safe_log(cumulative_scale_factors) - safe_log(noise_std)
|
||||||
|
return SolverTensors(
|
||||||
|
cumulative_scale_factors=cumulative_scale_factors,
|
||||||
|
noise_std=noise_std,
|
||||||
|
signal_to_noise_ratios=signal_to_noise_ratios,
|
||||||
|
)
|
||||||
|
|
||||||
def dpm_solver_first_order_update(
|
def dpm_solver_first_order_update(
|
||||||
self, x: Tensor, noise: Tensor, step: int, sde_noise: Tensor | None = None
|
self, x: torch.Tensor, noise: torch.Tensor, step: int, sde_noise: torch.Tensor | None = None
|
||||||
) -> Tensor:
|
) -> torch.Tensor:
|
||||||
"""Applies a first-order backward Euler update to the input data `x`.
|
"""Applies a first-order backward Euler update to the input data `x`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -109,32 +200,29 @@ class DPMSolver(Solver):
|
||||||
Returns:
|
Returns:
|
||||||
The denoised version of the input data `x`.
|
The denoised version of the input data `x`.
|
||||||
"""
|
"""
|
||||||
current_timestep = self.timesteps[step]
|
current_ratio = self.signal_to_noise_ratios[step]
|
||||||
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])
|
next_ratio = self.signal_to_noise_ratios[step + 1]
|
||||||
|
|
||||||
previous_ratio = self.signal_to_noise_ratios[previous_timestep]
|
next_scale_factor = self.cumulative_scale_factors[step + 1]
|
||||||
current_ratio = self.signal_to_noise_ratios[current_timestep]
|
|
||||||
|
|
||||||
previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
|
next_noise_std = self.noise_std[step + 1]
|
||||||
|
current_noise_std = self.noise_std[step]
|
||||||
|
|
||||||
previous_noise_std = self.noise_std[previous_timestep]
|
ratio_delta = current_ratio - next_ratio
|
||||||
current_noise_std = self.noise_std[current_timestep]
|
|
||||||
|
|
||||||
ratio_delta = current_ratio - previous_ratio
|
|
||||||
|
|
||||||
if sde_noise is None:
|
if sde_noise is None:
|
||||||
return (previous_noise_std / current_noise_std) * x + (
|
return (next_noise_std / current_noise_std) * x + (1.0 - torch.exp(ratio_delta)) * next_scale_factor * noise
|
||||||
1.0 - torch.exp(ratio_delta)
|
|
||||||
) * previous_scale_factor * noise
|
|
||||||
|
|
||||||
factor = 1.0 - torch.exp(2.0 * ratio_delta)
|
factor = 1.0 - torch.exp(2.0 * ratio_delta)
|
||||||
return (
|
return (
|
||||||
(previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
|
(next_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
|
||||||
+ previous_scale_factor * factor * noise
|
+ next_scale_factor * factor * noise
|
||||||
+ previous_noise_std * torch.sqrt(factor) * sde_noise
|
+ next_noise_std * safe_sqrt(factor) * sde_noise
|
||||||
)
|
)
|
||||||
|
|
||||||
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int, sde_noise: Tensor | None = None) -> Tensor:
|
def multistep_dpm_solver_second_order_update(
|
||||||
|
self, x: torch.Tensor, step: int, sde_noise: torch.Tensor | None = None
|
||||||
|
) -> torch.Tensor:
|
||||||
"""Applies a second-order backward Euler update to the input data `x`.
|
"""Applies a second-order backward Euler update to the input data `x`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -144,43 +232,41 @@ class DPMSolver(Solver):
|
||||||
Returns:
|
Returns:
|
||||||
The denoised version of the input data `x`.
|
The denoised version of the input data `x`.
|
||||||
"""
|
"""
|
||||||
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])
|
|
||||||
current_timestep = self.timesteps[step]
|
|
||||||
next_timestep = self.timesteps[step - 1]
|
|
||||||
|
|
||||||
current_data_estimation = self.estimated_data[-1]
|
current_data_estimation = self.estimated_data[-1]
|
||||||
next_data_estimation = self.estimated_data[-2]
|
previous_data_estimation = self.estimated_data[-2]
|
||||||
|
|
||||||
previous_ratio = self.signal_to_noise_ratios[previous_timestep]
|
next_ratio = self.signal_to_noise_ratios[step + 1]
|
||||||
current_ratio = self.signal_to_noise_ratios[current_timestep]
|
current_ratio = self.signal_to_noise_ratios[step]
|
||||||
next_ratio = self.signal_to_noise_ratios[next_timestep]
|
previous_ratio = self.signal_to_noise_ratios[step - 1]
|
||||||
|
|
||||||
previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
|
next_scale_factor = self.cumulative_scale_factors[step + 1]
|
||||||
previous_noise_std = self.noise_std[previous_timestep]
|
next_noise_std = self.noise_std[step + 1]
|
||||||
current_noise_std = self.noise_std[current_timestep]
|
current_noise_std = self.noise_std[step]
|
||||||
|
|
||||||
estimation_delta = (current_data_estimation - next_data_estimation) / (
|
estimation_delta = (current_data_estimation - previous_data_estimation) / (
|
||||||
(current_ratio - next_ratio) / (previous_ratio - current_ratio)
|
(current_ratio - previous_ratio) / (next_ratio - current_ratio)
|
||||||
)
|
)
|
||||||
ratio_delta = current_ratio - previous_ratio
|
ratio_delta = current_ratio - next_ratio
|
||||||
|
|
||||||
if sde_noise is None:
|
if sde_noise is None:
|
||||||
factor = 1.0 - torch.exp(ratio_delta)
|
factor = 1.0 - torch.exp(ratio_delta)
|
||||||
return (
|
return (
|
||||||
(previous_noise_std / current_noise_std) * x
|
(next_noise_std / current_noise_std) * x
|
||||||
+ previous_scale_factor * factor * current_data_estimation
|
+ next_scale_factor * factor * current_data_estimation
|
||||||
+ 0.5 * previous_scale_factor * factor * estimation_delta
|
+ 0.5 * next_scale_factor * factor * estimation_delta
|
||||||
)
|
)
|
||||||
|
|
||||||
factor = 1.0 - torch.exp(2.0 * ratio_delta)
|
factor = 1.0 - torch.exp(2.0 * ratio_delta)
|
||||||
return (
|
return (
|
||||||
(previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
|
(next_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
|
||||||
+ previous_scale_factor * factor * current_data_estimation
|
+ next_scale_factor * factor * current_data_estimation
|
||||||
+ 0.5 * previous_scale_factor * factor * estimation_delta
|
+ 0.5 * next_scale_factor * factor * estimation_delta
|
||||||
+ previous_noise_std * torch.sqrt(factor) * sde_noise
|
+ next_noise_std * safe_sqrt(factor) * sde_noise
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
def __call__(
|
||||||
|
self, x: torch.Tensor, predicted_noise: torch.Tensor, step: int, generator: torch.Generator | None = None
|
||||||
|
) -> torch.Tensor:
|
||||||
"""Apply one step of the backward diffusion process.
|
"""Apply one step of the backward diffusion process.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
|
@ -199,9 +285,8 @@ class DPMSolver(Solver):
|
||||||
"""
|
"""
|
||||||
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
||||||
|
|
||||||
current_timestep = self.timesteps[step]
|
scale_factor = self.cumulative_scale_factors[step]
|
||||||
scale_factor = self.cumulative_scale_factors[current_timestep]
|
noise_ratio = self.noise_std[step]
|
||||||
noise_ratio = self.noise_std[current_timestep]
|
|
||||||
estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor
|
estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor
|
||||||
self.estimated_data.append(estimated_denoised_data)
|
self.estimated_data.append(estimated_denoised_data)
|
||||||
variance = self.params.sde_variance
|
variance = self.params.sde_variance
|
||||||
|
|
|
@ -67,6 +67,7 @@ class BaseSolverParams:
|
||||||
initial_diffusion_rate: float | None
|
initial_diffusion_rate: float | None
|
||||||
final_diffusion_rate: float | None
|
final_diffusion_rate: float | None
|
||||||
noise_schedule: NoiseSchedule | None
|
noise_schedule: NoiseSchedule | None
|
||||||
|
sigma_schedule: NoiseSchedule | None
|
||||||
model_prediction_type: ModelPredictionType | None
|
model_prediction_type: ModelPredictionType | None
|
||||||
sde_variance: float
|
sde_variance: float
|
||||||
|
|
||||||
|
@ -91,6 +92,7 @@ class SolverParams(BaseSolverParams):
|
||||||
initial_diffusion_rate: float | None = None
|
initial_diffusion_rate: float | None = None
|
||||||
final_diffusion_rate: float | None = None
|
final_diffusion_rate: float | None = None
|
||||||
noise_schedule: NoiseSchedule | None = None
|
noise_schedule: NoiseSchedule | None = None
|
||||||
|
sigma_schedule: NoiseSchedule | None = None
|
||||||
model_prediction_type: ModelPredictionType | None = None
|
model_prediction_type: ModelPredictionType | None = None
|
||||||
sde_variance: float = 0.0
|
sde_variance: float = 0.0
|
||||||
|
|
||||||
|
@ -103,6 +105,7 @@ class ResolvedSolverParams(BaseSolverParams):
|
||||||
initial_diffusion_rate: float
|
initial_diffusion_rate: float
|
||||||
final_diffusion_rate: float
|
final_diffusion_rate: float
|
||||||
noise_schedule: NoiseSchedule
|
noise_schedule: NoiseSchedule
|
||||||
|
sigma_schedule: NoiseSchedule | None
|
||||||
model_prediction_type: ModelPredictionType
|
model_prediction_type: ModelPredictionType
|
||||||
sde_variance: float
|
sde_variance: float
|
||||||
|
|
||||||
|
@ -140,6 +143,7 @@ class Solver(fl.Module, ABC):
|
||||||
initial_diffusion_rate=8.5e-4,
|
initial_diffusion_rate=8.5e-4,
|
||||||
final_diffusion_rate=1.2e-2,
|
final_diffusion_rate=1.2e-2,
|
||||||
noise_schedule=NoiseSchedule.QUADRATIC,
|
noise_schedule=NoiseSchedule.QUADRATIC,
|
||||||
|
sigma_schedule=None,
|
||||||
model_prediction_type=ModelPredictionType.NOISE,
|
model_prediction_type=ModelPredictionType.NOISE,
|
||||||
sde_variance=0.0,
|
sde_variance=0.0,
|
||||||
)
|
)
|
||||||
|
@ -404,14 +408,12 @@ class Solver(fl.Module, ABC):
|
||||||
A tensor representing the noise schedule.
|
A tensor representing the noise schedule.
|
||||||
"""
|
"""
|
||||||
match self.params.noise_schedule:
|
match self.params.noise_schedule:
|
||||||
case "uniform":
|
case NoiseSchedule.UNIFORM:
|
||||||
return 1 - self.sample_power_distribution(1)
|
return 1 - self.sample_power_distribution(1)
|
||||||
case "quadratic":
|
case NoiseSchedule.QUADRATIC:
|
||||||
return 1 - self.sample_power_distribution(2)
|
return 1 - self.sample_power_distribution(2)
|
||||||
case "karras":
|
case NoiseSchedule.KARRAS:
|
||||||
return 1 - self.sample_power_distribution(7)
|
return 1 - self.sample_power_distribution(7)
|
||||||
case _:
|
|
||||||
raise ValueError(f"Unknown noise schedule: {self.params.noise_schedule}")
|
|
||||||
|
|
||||||
def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver":
|
def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver":
|
||||||
"""Move the solver to the specified device and data type.
|
"""Move the solver to the specified device and data type.
|
||||||
|
|
|
@ -97,6 +97,11 @@ def expected_image_std_sde_random_init(ref_path: Path) -> Image.Image:
|
||||||
return _img_open(ref_path / "expected_std_sde_random_init.png").convert("RGB")
|
return _img_open(ref_path / "expected_std_sde_random_init.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def expected_image_std_sde_karras_random_init(ref_path: Path) -> Image.Image:
|
||||||
|
return _img_open(ref_path / "expected_std_sde_karras_random_init.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
|
def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
|
||||||
return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB")
|
return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB")
|
||||||
|
@ -913,6 +918,39 @@ def test_diffusion_std_sde_random_init(
|
||||||
ensure_similar_images(predicted_image, expected_image_std_sde_random_init)
|
ensure_similar_images(predicted_image, expected_image_std_sde_random_init)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_diffusion_std_sde_karras_random_init(
|
||||||
|
sd15_std_sde: StableDiffusion_1, expected_image_std_sde_karras_random_init: Image.Image, test_device: torch.device
|
||||||
|
):
|
||||||
|
sd15 = sd15_std_sde
|
||||||
|
|
||||||
|
prompt = "a cute cat, detailed high-quality professional image"
|
||||||
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||||
|
|
||||||
|
sd15.solver = DPMSolver(
|
||||||
|
num_inference_steps=18,
|
||||||
|
last_step_first_order=True,
|
||||||
|
params=SolverParams(sde_variance=1.0, sigma_schedule=NoiseSchedule.KARRAS),
|
||||||
|
device=test_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
manual_seed(2)
|
||||||
|
x = sd15.init_latents((512, 512))
|
||||||
|
|
||||||
|
for step in sd15.steps:
|
||||||
|
x = sd15(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
condition_scale=7.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
predicted_image = sd15.lda.latents_to_image(x)
|
||||||
|
|
||||||
|
ensure_similar_images(predicted_image, expected_image_std_sde_karras_random_init)
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_diffusion_batch2(sd15_std: StableDiffusion_1):
|
def test_diffusion_batch2(sd15_std: StableDiffusion_1):
|
||||||
sd15 = sd15_std
|
sd15 = sd15_std
|
||||||
|
|
|
@ -97,6 +97,29 @@ manual_seed(2)
|
||||||
image = pipe(prompt, negative_prompt=negative_prompt, guidance_scale=7.5).images[0]
|
image = pipe(prompt, negative_prompt=negative_prompt, guidance_scale=7.5).images[0]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
- `expected_std_sde_karras_random_init.png` is generated with the following code (diffusers 0.30.2):
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffusers import StableDiffusionPipeline
|
||||||
|
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
|
||||||
|
from refiners.fluxion.utils import manual_seed
|
||||||
|
|
||||||
|
model_id = "botp/stable-diffusion-v1-5"
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
|
||||||
|
pipe = pipe.to("cuda:1")
|
||||||
|
|
||||||
|
config = {**pipe.scheduler.config}
|
||||||
|
config["use_karras_sigmas"] = True
|
||||||
|
config["algorithm_type"] = "sde-dpmsolver++"
|
||||||
|
pipe.scheduler = DPMSolverMultistepScheduler.from_config(config)
|
||||||
|
|
||||||
|
prompt = "a cute cat, detailed high-quality professional image"
|
||||||
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
manual_seed(2)
|
||||||
|
image = pipe(prompt, negative_prompt=negative_prompt, num_inference_steps=18, guidance_scale=7.5).images[0]
|
||||||
|
```
|
||||||
|
|
||||||
- `kitchen_mask.png` is made manually.
|
- `kitchen_mask.png` is made manually.
|
||||||
|
|
||||||
- Controlnet guides have been manually generated (x) using open source software and models, namely:
|
- Controlnet guides have been manually generated (x) using open source software and models, namely:
|
||||||
|
|
Binary file not shown.
After Width: | Height: | Size: 378 KiB |
|
@ -1,3 +1,4 @@
|
||||||
|
import itertools
|
||||||
from typing import cast
|
from typing import cast
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
|
|
||||||
|
@ -29,8 +30,11 @@ def test_ddpm_diffusers():
|
||||||
assert equal(diffusers_scheduler.timesteps, solver.timesteps)
|
assert equal(diffusers_scheduler.timesteps, solver.timesteps)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("n_steps, last_step_first_order", [(5, False), (5, True), (30, False), (30, True)])
|
@pytest.mark.parametrize(
|
||||||
def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
|
"n_steps, last_step_first_order, sde_variance, use_karras_sigmas",
|
||||||
|
list(itertools.product([5, 30], [False, True], [0.0, 1.0], [False, True])),
|
||||||
|
)
|
||||||
|
def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool, sde_variance: float, use_karras_sigmas: bool):
|
||||||
from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler # type: ignore
|
from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler # type: ignore
|
||||||
|
|
||||||
manual_seed(0)
|
manual_seed(0)
|
||||||
|
@ -42,43 +46,17 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
|
||||||
lower_order_final=False,
|
lower_order_final=False,
|
||||||
euler_at_final=last_step_first_order,
|
euler_at_final=last_step_first_order,
|
||||||
final_sigmas_type="sigma_min", # default before Diffusers 0.26.0
|
final_sigmas_type="sigma_min", # default before Diffusers 0.26.0
|
||||||
|
algorithm_type="sde-dpmsolver++" if sde_variance == 1.0 else "dpmsolver++",
|
||||||
|
use_karras_sigmas=use_karras_sigmas,
|
||||||
)
|
)
|
||||||
diffusers_scheduler.set_timesteps(n_steps)
|
diffusers_scheduler.set_timesteps(n_steps)
|
||||||
solver = DPMSolver(
|
solver = DPMSolver(
|
||||||
num_inference_steps=n_steps,
|
num_inference_steps=n_steps,
|
||||||
last_step_first_order=last_step_first_order,
|
last_step_first_order=last_step_first_order,
|
||||||
)
|
params=SolverParams(
|
||||||
assert equal(solver.timesteps, diffusers_scheduler.timesteps)
|
sde_variance=sde_variance,
|
||||||
|
sigma_schedule=NoiseSchedule.KARRAS if use_karras_sigmas else None,
|
||||||
sample = randn(1, 3, 32, 32)
|
),
|
||||||
predicted_noise = randn(1, 3, 32, 32)
|
|
||||||
|
|
||||||
for step, timestep in enumerate(diffusers_scheduler.timesteps):
|
|
||||||
diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
|
|
||||||
refiners_output = solver(x=sample, predicted_noise=predicted_noise, step=step)
|
|
||||||
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("n_steps, last_step_first_order", [(5, False), (5, True), (30, False), (30, True)])
|
|
||||||
def test_dpm_solver_sde_diffusers(n_steps: int, last_step_first_order: bool):
|
|
||||||
from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler # type: ignore
|
|
||||||
|
|
||||||
manual_seed(0)
|
|
||||||
|
|
||||||
diffusers_scheduler = DiffuserScheduler(
|
|
||||||
beta_schedule="scaled_linear",
|
|
||||||
beta_start=0.00085,
|
|
||||||
beta_end=0.012,
|
|
||||||
lower_order_final=False,
|
|
||||||
euler_at_final=last_step_first_order,
|
|
||||||
final_sigmas_type="sigma_min", # default before Diffusers 0.26.0
|
|
||||||
algorithm_type="sde-dpmsolver++",
|
|
||||||
)
|
|
||||||
diffusers_scheduler.set_timesteps(n_steps)
|
|
||||||
solver = DPMSolver(
|
|
||||||
num_inference_steps=n_steps,
|
|
||||||
last_step_first_order=last_step_first_order,
|
|
||||||
params=SolverParams(sde_variance=1.0),
|
|
||||||
)
|
)
|
||||||
assert equal(solver.timesteps, diffusers_scheduler.timesteps)
|
assert equal(solver.timesteps, diffusers_scheduler.timesteps)
|
||||||
|
|
||||||
|
@ -94,8 +72,9 @@ def test_dpm_solver_sde_diffusers(n_steps: int, last_step_first_order: bool):
|
||||||
manual_seed(37)
|
manual_seed(37)
|
||||||
refiners_outputs = [solver(x=sample, predicted_noise=predicted_noise, step=step) for step in range(n_steps)]
|
refiners_outputs = [solver(x=sample, predicted_noise=predicted_noise, step=step) for step in range(n_steps)]
|
||||||
|
|
||||||
|
atol = 1e-4 if use_karras_sigmas else 1e-6
|
||||||
for step, (diffusers_output, refiners_output) in enumerate(zip(diffusers_outputs, refiners_outputs)):
|
for step, (diffusers_output, refiners_output) in enumerate(zip(diffusers_outputs, refiners_outputs)):
|
||||||
assert allclose(diffusers_output, refiners_output, rtol=0.01, atol=1e-6), f"outputs differ at step {step}"
|
assert allclose(diffusers_output, refiners_output, rtol=0.01, atol=atol), f"outputs differ at step {step}"
|
||||||
|
|
||||||
|
|
||||||
def test_ddim_diffusers():
|
def test_ddim_diffusers():
|
||||||
|
|
Loading…
Reference in a new issue