Add stochastic sampling to DPM solver (SDE)
Some checks are pending
CI / lint_and_typecheck (push) Waiting to run
Deploy docs to GitHub Pages / Deploy docs (push) Waiting to run
Spell checker / Spell check (push) Waiting to run

This commit is contained in:
limiteinductive 2024-07-23 08:52:40 +00:00 committed by Benjamin Trom
parent daee77298d
commit 09a9dfd494
8 changed files with 188 additions and 15 deletions

View file

@ -41,6 +41,8 @@ class DDIM(Solver):
"""
if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
raise NotImplementedError
if params and params.sde_variance != 0.0:
raise NotImplementedError("DDIM does not support sde_variance != 0.0 yet")
super().__init__(
num_inference_steps=num_inference_steps,

View file

@ -2,7 +2,8 @@ import dataclasses
from collections import deque
import numpy as np
from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor
import torch
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
from refiners.foundationals.latent_diffusion.solvers.solver import (
BaseSolverParams,
@ -51,6 +52,8 @@ class DPMSolver(Solver):
"""
if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
raise NotImplementedError
if params and params.sde_variance not in (0.0, 1.0):
raise NotImplementedError("DPMSolver only supports sde_variance=0.0 or 1.0")
super().__init__(
num_inference_steps=num_inference_steps,
@ -93,7 +96,9 @@ class DPMSolver(Solver):
np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
return tensor(np_space).flip(0)
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
def dpm_solver_first_order_update(
self, x: Tensor, noise: Tensor, step: int, sde_noise: Tensor | None = None
) -> Tensor:
"""Applies a first-order backward Euler update to the input data `x`.
Args:
@ -115,11 +120,21 @@ class DPMSolver(Solver):
previous_noise_std = self.noise_std[previous_timestep]
current_noise_std = self.noise_std[current_timestep]
factor = exp(-(previous_ratio - current_ratio)) - 1.0
denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise
return denoised_x
ratio_delta = current_ratio - previous_ratio
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
if sde_noise is None:
return (previous_noise_std / current_noise_std) * x + (
1.0 - torch.exp(ratio_delta)
) * previous_scale_factor * noise
factor = 1.0 - torch.exp(2.0 * ratio_delta)
return (
(previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
+ previous_scale_factor * factor * noise
+ previous_noise_std * torch.sqrt(factor) * sde_noise
)
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int, sde_noise: Tensor | None = None) -> Tensor:
"""Applies a second-order backward Euler update to the input data `x`.
Args:
@ -147,13 +162,23 @@ class DPMSolver(Solver):
estimation_delta = (current_data_estimation - next_data_estimation) / (
(current_ratio - next_ratio) / (previous_ratio - current_ratio)
)
factor = exp(-(previous_ratio - current_ratio)) - 1.0
denoised_x = (
(previous_noise_std / current_noise_std) * x
- (factor * previous_scale_factor) * current_data_estimation
- 0.5 * (factor * previous_scale_factor) * estimation_delta
ratio_delta = current_ratio - previous_ratio
if sde_noise is None:
factor = 1.0 - torch.exp(ratio_delta)
return (
(previous_noise_std / current_noise_std) * x
+ previous_scale_factor * factor * current_data_estimation
+ 0.5 * previous_scale_factor * factor * estimation_delta
)
factor = 1.0 - torch.exp(2.0 * ratio_delta)
return (
(previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
+ previous_scale_factor * factor * current_data_estimation
+ 0.5 * previous_scale_factor * factor * estimation_delta
+ previous_noise_std * torch.sqrt(factor) * sde_noise
)
return denoised_x
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
"""Apply one step of the backward diffusion process.
@ -175,11 +200,20 @@ class DPMSolver(Solver):
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
current_timestep = self.timesteps[step]
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
scale_factor = self.cumulative_scale_factors[current_timestep]
noise_ratio = self.noise_std[current_timestep]
estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor
self.estimated_data.append(estimated_denoised_data)
variance = self.params.sde_variance
sde_noise = (
torch.randn(x.shape, generator=generator, device=x.device, dtype=x.dtype) * variance
if variance > 0.0
else None
)
if step == self.first_inference_step or (self.last_step_first_order and step == self.num_inference_steps - 1):
return self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
return self.dpm_solver_first_order_update(
x=x, noise=estimated_denoised_data, step=step, sde_noise=sde_noise
)
return self.multistep_dpm_solver_second_order_update(x=x, step=step)
return self.multistep_dpm_solver_second_order_update(x=x, step=step, sde_noise=sde_noise)

View file

@ -36,6 +36,8 @@ class Euler(Solver):
"""
if params and params.noise_schedule not in (NoiseSchedule.QUADRATIC, None):
raise NotImplementedError
if params and params.sde_variance != 0.0:
raise NotImplementedError("Euler does not support sde_variance != 0.0 yet")
super().__init__(
num_inference_steps=num_inference_steps,

View file

@ -79,6 +79,7 @@ class BaseSolverParams:
final_diffusion_rate: float | None
noise_schedule: NoiseSchedule | None
model_prediction_type: ModelPredictionType | None
sde_variance: float
@dataclasses.dataclass(kw_only=True, frozen=True)
@ -102,6 +103,7 @@ class SolverParams(BaseSolverParams):
final_diffusion_rate: float | None = None
noise_schedule: NoiseSchedule | None = None
model_prediction_type: ModelPredictionType | None = None
sde_variance: float = 0.0
@dataclasses.dataclass(kw_only=True, frozen=True)
@ -113,6 +115,7 @@ class ResolvedSolverParams(BaseSolverParams):
final_diffusion_rate: float
noise_schedule: NoiseSchedule
model_prediction_type: ModelPredictionType
sde_variance: float
class Solver(fl.Module, ABC):
@ -123,6 +126,19 @@ class Solver(fl.Module, ABC):
This process is described using several parameters such as initial and final diffusion rates,
and is encapsulated into a `__call__` method that applies a step of the diffusion process.
Attributes:
params: The common parameters for solvers. See `SolverParams`.
num_inference_steps: The number of inference steps to perform.
first_inference_step: The step to start the inference process from.
scale_factors: The scale factors used to denoise the input. These are called "betas" in other implementations,
and `1 - scale_factors` is called "alphas".
cumulative_scale_factors: The cumulative scale factors used to denoise the input. These are called "alpha_t" in
other implementations.
noise_std: The standard deviation of the noise used to denoise the input. This is called "sigma_t" in other
implementations.
signal_to_noise_ratios: The signal-to-noise ratios used to denoise the input. This is called "lambda_t" in other
implementations.
"""
timesteps: Tensor
@ -136,6 +152,7 @@ class Solver(fl.Module, ABC):
final_diffusion_rate=1.2e-2,
noise_schedule=NoiseSchedule.QUADRATIC,
model_prediction_type=ModelPredictionType.NOISE,
sde_variance=0.0,
)
def __init__(

View file

@ -88,6 +88,11 @@ def expected_image_std_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_random_init.png").convert("RGB")
@pytest.fixture
def expected_image_std_sde_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_sde_random_init.png").convert("RGB")
@pytest.fixture
def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB")
@ -560,6 +565,24 @@ def sd15_std(
return sd15
@pytest.fixture
def sd15_std_sde(
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
) -> StableDiffusion_1:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
pytest.skip()
sde_solver = DPMSolver(num_inference_steps=30, last_step_first_order=True, params=SolverParams(sde_variance=1.0))
sd15 = StableDiffusion_1(device=test_device, solver=sde_solver)
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
sd15.unet.load_from_safetensors(unet_weights_std)
return sd15
@pytest.fixture
def sd15_std_float16(
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
@ -831,6 +854,33 @@ def test_diffusion_std_random_init(
ensure_similar_images(predicted_image, expected_image_std_random_init)
@no_grad()
def test_diffusion_std_sde_random_init(
sd15_std_sde: StableDiffusion_1, expected_image_std_sde_random_init: Image.Image, test_device: torch.device
):
sd15 = sd15_std_sde
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_inference_steps(50)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image_std_sde_random_init)
@no_grad()
def test_diffusion_batch2(sd15_std: StableDiffusion_1):
sd15 = sd15_std

View file

@ -67,6 +67,35 @@ Special cases:
- `kitchen_dog.png` is generated with the same Diffusers script and negative prompt, seed 12, positive prompt "a small brown dog, detailed high-quality professional image, sitting on a chair, in a kitchen".
- `expected_std_sde_random_init.png` is generated with the following code:
```python
import torch
from diffusers import StableDiffusionPipeline
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from refiners.fluxion.utils import manual_seed
diffusers_solver = DPMSolverMultistepScheduler.from_config( # type: ignore
{
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"algorithm_type": "sde-dpmsolver++",
"use_karras_sigmas": False,
"final_sigmas_type": "sigma_min",
"euler_at_final": True,
}
)
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32, scheduler=diffusers_solver)
pipe = pipe.to("cuda")
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
manual_seed(2)
image = pipe(prompt, negative_prompt=negative_prompt, guidance_scale=7.5).images[0]
```
- `kitchen_mask.png` is made manually.
- Controlnet guides have been manually generated (x) using open source software and models, namely:

Binary file not shown.

After

Width:  |  Height:  |  Size: 349 KiB

View file

@ -59,6 +59,45 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
@pytest.mark.parametrize("n_steps, last_step_first_order", [(5, False), (5, True), (30, False), (30, True)])
def test_dpm_solver_sde_diffusers(n_steps: int, last_step_first_order: bool):
from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler # type: ignore
manual_seed(0)
diffusers_scheduler = DiffuserScheduler(
beta_schedule="scaled_linear",
beta_start=0.00085,
beta_end=0.012,
lower_order_final=False,
euler_at_final=last_step_first_order,
final_sigmas_type="sigma_min", # default before Diffusers 0.26.0
algorithm_type="sde-dpmsolver++",
)
diffusers_scheduler.set_timesteps(n_steps)
solver = DPMSolver(
num_inference_steps=n_steps,
last_step_first_order=last_step_first_order,
params=SolverParams(sde_variance=1.0),
)
assert equal(solver.timesteps, diffusers_scheduler.timesteps)
sample = randn(1, 3, 32, 32)
predicted_noise = randn(1, 3, 32, 32)
manual_seed(37)
diffusers_outputs: list[Tensor] = [
cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
for timestep in diffusers_scheduler.timesteps
]
manual_seed(37)
refiners_outputs = [solver(x=sample, predicted_noise=predicted_noise, step=step) for step in range(n_steps)]
for step, (diffusers_output, refiners_output) in enumerate(zip(diffusers_outputs, refiners_outputs)):
assert allclose(diffusers_output, refiners_output, rtol=0.01, atol=1e-6), f"outputs differ at step {step}"
def test_ddim_diffusers():
from diffusers import DDIMScheduler # type: ignore