Add stochastic sampling to DPM solver (SDE)
Some checks are pending
CI / lint_and_typecheck (push) Waiting to run
Deploy docs to GitHub Pages / Deploy docs (push) Waiting to run
Spell checker / Spell check (push) Waiting to run

This commit is contained in:
limiteinductive 2024-07-23 08:52:40 +00:00 committed by Benjamin Trom
parent daee77298d
commit 09a9dfd494
8 changed files with 188 additions and 15 deletions

View file

@ -41,6 +41,8 @@ class DDIM(Solver):
""" """
if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None): if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
raise NotImplementedError raise NotImplementedError
if params and params.sde_variance != 0.0:
raise NotImplementedError("DDIM does not support sde_variance != 0.0 yet")
super().__init__( super().__init__(
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,

View file

@ -2,7 +2,8 @@ import dataclasses
from collections import deque from collections import deque
import numpy as np import numpy as np
from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor import torch
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
from refiners.foundationals.latent_diffusion.solvers.solver import ( from refiners.foundationals.latent_diffusion.solvers.solver import (
BaseSolverParams, BaseSolverParams,
@ -51,6 +52,8 @@ class DPMSolver(Solver):
""" """
if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None): if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
raise NotImplementedError raise NotImplementedError
if params and params.sde_variance not in (0.0, 1.0):
raise NotImplementedError("DPMSolver only supports sde_variance=0.0 or 1.0")
super().__init__( super().__init__(
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,
@ -93,7 +96,9 @@ class DPMSolver(Solver):
np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:] np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
return tensor(np_space).flip(0) return tensor(np_space).flip(0)
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor: def dpm_solver_first_order_update(
self, x: Tensor, noise: Tensor, step: int, sde_noise: Tensor | None = None
) -> Tensor:
"""Applies a first-order backward Euler update to the input data `x`. """Applies a first-order backward Euler update to the input data `x`.
Args: Args:
@ -115,11 +120,21 @@ class DPMSolver(Solver):
previous_noise_std = self.noise_std[previous_timestep] previous_noise_std = self.noise_std[previous_timestep]
current_noise_std = self.noise_std[current_timestep] current_noise_std = self.noise_std[current_timestep]
factor = exp(-(previous_ratio - current_ratio)) - 1.0 ratio_delta = current_ratio - previous_ratio
denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise
return denoised_x
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor: if sde_noise is None:
return (previous_noise_std / current_noise_std) * x + (
1.0 - torch.exp(ratio_delta)
) * previous_scale_factor * noise
factor = 1.0 - torch.exp(2.0 * ratio_delta)
return (
(previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
+ previous_scale_factor * factor * noise
+ previous_noise_std * torch.sqrt(factor) * sde_noise
)
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int, sde_noise: Tensor | None = None) -> Tensor:
"""Applies a second-order backward Euler update to the input data `x`. """Applies a second-order backward Euler update to the input data `x`.
Args: Args:
@ -147,13 +162,23 @@ class DPMSolver(Solver):
estimation_delta = (current_data_estimation - next_data_estimation) / ( estimation_delta = (current_data_estimation - next_data_estimation) / (
(current_ratio - next_ratio) / (previous_ratio - current_ratio) (current_ratio - next_ratio) / (previous_ratio - current_ratio)
) )
factor = exp(-(previous_ratio - current_ratio)) - 1.0 ratio_delta = current_ratio - previous_ratio
denoised_x = (
if sde_noise is None:
factor = 1.0 - torch.exp(ratio_delta)
return (
(previous_noise_std / current_noise_std) * x (previous_noise_std / current_noise_std) * x
- (factor * previous_scale_factor) * current_data_estimation + previous_scale_factor * factor * current_data_estimation
- 0.5 * (factor * previous_scale_factor) * estimation_delta + 0.5 * previous_scale_factor * factor * estimation_delta
)
factor = 1.0 - torch.exp(2.0 * ratio_delta)
return (
(previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
+ previous_scale_factor * factor * current_data_estimation
+ 0.5 * previous_scale_factor * factor * estimation_delta
+ previous_noise_std * torch.sqrt(factor) * sde_noise
) )
return denoised_x
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
"""Apply one step of the backward diffusion process. """Apply one step of the backward diffusion process.
@ -175,11 +200,20 @@ class DPMSolver(Solver):
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}" assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
current_timestep = self.timesteps[step] current_timestep = self.timesteps[step]
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep] scale_factor = self.cumulative_scale_factors[current_timestep]
noise_ratio = self.noise_std[current_timestep]
estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor
self.estimated_data.append(estimated_denoised_data) self.estimated_data.append(estimated_denoised_data)
variance = self.params.sde_variance
sde_noise = (
torch.randn(x.shape, generator=generator, device=x.device, dtype=x.dtype) * variance
if variance > 0.0
else None
)
if step == self.first_inference_step or (self.last_step_first_order and step == self.num_inference_steps - 1): if step == self.first_inference_step or (self.last_step_first_order and step == self.num_inference_steps - 1):
return self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step) return self.dpm_solver_first_order_update(
x=x, noise=estimated_denoised_data, step=step, sde_noise=sde_noise
)
return self.multistep_dpm_solver_second_order_update(x=x, step=step) return self.multistep_dpm_solver_second_order_update(x=x, step=step, sde_noise=sde_noise)

View file

@ -36,6 +36,8 @@ class Euler(Solver):
""" """
if params and params.noise_schedule not in (NoiseSchedule.QUADRATIC, None): if params and params.noise_schedule not in (NoiseSchedule.QUADRATIC, None):
raise NotImplementedError raise NotImplementedError
if params and params.sde_variance != 0.0:
raise NotImplementedError("Euler does not support sde_variance != 0.0 yet")
super().__init__( super().__init__(
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,

View file

@ -79,6 +79,7 @@ class BaseSolverParams:
final_diffusion_rate: float | None final_diffusion_rate: float | None
noise_schedule: NoiseSchedule | None noise_schedule: NoiseSchedule | None
model_prediction_type: ModelPredictionType | None model_prediction_type: ModelPredictionType | None
sde_variance: float
@dataclasses.dataclass(kw_only=True, frozen=True) @dataclasses.dataclass(kw_only=True, frozen=True)
@ -102,6 +103,7 @@ class SolverParams(BaseSolverParams):
final_diffusion_rate: float | None = None final_diffusion_rate: float | None = None
noise_schedule: NoiseSchedule | None = None noise_schedule: NoiseSchedule | None = None
model_prediction_type: ModelPredictionType | None = None model_prediction_type: ModelPredictionType | None = None
sde_variance: float = 0.0
@dataclasses.dataclass(kw_only=True, frozen=True) @dataclasses.dataclass(kw_only=True, frozen=True)
@ -113,6 +115,7 @@ class ResolvedSolverParams(BaseSolverParams):
final_diffusion_rate: float final_diffusion_rate: float
noise_schedule: NoiseSchedule noise_schedule: NoiseSchedule
model_prediction_type: ModelPredictionType model_prediction_type: ModelPredictionType
sde_variance: float
class Solver(fl.Module, ABC): class Solver(fl.Module, ABC):
@ -123,6 +126,19 @@ class Solver(fl.Module, ABC):
This process is described using several parameters such as initial and final diffusion rates, This process is described using several parameters such as initial and final diffusion rates,
and is encapsulated into a `__call__` method that applies a step of the diffusion process. and is encapsulated into a `__call__` method that applies a step of the diffusion process.
Attributes:
params: The common parameters for solvers. See `SolverParams`.
num_inference_steps: The number of inference steps to perform.
first_inference_step: The step to start the inference process from.
scale_factors: The scale factors used to denoise the input. These are called "betas" in other implementations,
and `1 - scale_factors` is called "alphas".
cumulative_scale_factors: The cumulative scale factors used to denoise the input. These are called "alpha_t" in
other implementations.
noise_std: The standard deviation of the noise used to denoise the input. This is called "sigma_t" in other
implementations.
signal_to_noise_ratios: The signal-to-noise ratios used to denoise the input. This is called "lambda_t" in other
implementations.
""" """
timesteps: Tensor timesteps: Tensor
@ -136,6 +152,7 @@ class Solver(fl.Module, ABC):
final_diffusion_rate=1.2e-2, final_diffusion_rate=1.2e-2,
noise_schedule=NoiseSchedule.QUADRATIC, noise_schedule=NoiseSchedule.QUADRATIC,
model_prediction_type=ModelPredictionType.NOISE, model_prediction_type=ModelPredictionType.NOISE,
sde_variance=0.0,
) )
def __init__( def __init__(

View file

@ -88,6 +88,11 @@ def expected_image_std_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_random_init.png").convert("RGB") return _img_open(ref_path / "expected_std_random_init.png").convert("RGB")
@pytest.fixture
def expected_image_std_sde_random_init(ref_path: Path) -> Image.Image:
    """Load the reference image for the SDE random-init diffusion test, converted to RGB."""
    reference_file = ref_path / "expected_std_sde_random_init.png"
    reference_image = _img_open(reference_file)
    return reference_image.convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image: def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB") return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB")
@ -560,6 +565,24 @@ def sd15_std(
return sd15 return sd15
@pytest.fixture
def sd15_std_sde(
    text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
) -> StableDiffusion_1:
    """Provide a Stable Diffusion 1 model driven by a DPM solver with `sde_variance=1.0`.

    The requesting test is skipped (with a warning) when the test device is a CPU.
    """
    running_on_cpu = test_device.type == "cpu"
    if running_on_cpu:
        warn("not running on CPU, skipping")
        pytest.skip()

    model = StableDiffusion_1(
        device=test_device,
        solver=DPMSolver(
            num_inference_steps=30,
            last_step_first_order=True,
            params=SolverParams(sde_variance=1.0),
        ),
    )

    # Load the weights of each sub-model, in the same order as the other fixtures do.
    model.clip_text_encoder.load_from_safetensors(text_encoder_weights)
    model.lda.load_from_safetensors(lda_weights)
    model.unet.load_from_safetensors(unet_weights_std)

    return model
@pytest.fixture @pytest.fixture
def sd15_std_float16( def sd15_std_float16(
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
@ -831,6 +854,33 @@ def test_diffusion_std_random_init(
ensure_similar_images(predicted_image, expected_image_std_random_init) ensure_similar_images(predicted_image, expected_image_std_random_init)
@no_grad()
def test_diffusion_std_sde_random_init(
    sd15_std_sde: StableDiffusion_1, expected_image_std_sde_random_init: Image.Image, test_device: torch.device
):
    """Run a full SDE-solver diffusion from random latents and compare against the reference image."""
    model = sd15_std_sde

    text_embedding = model.compute_clip_text_embedding(
        text="a cute cat, detailed high-quality professional image",
        negative_text="lowres, bad anatomy, bad hands, cropped, worst quality",
    )

    model.set_inference_steps(50)

    # Fix the seed before sampling the initial latents so the run is reproducible.
    manual_seed(2)
    latents = torch.randn(1, 4, 64, 64, device=test_device)

    for step in model.steps:
        latents = model(latents, step=step, clip_text_embedding=text_embedding, condition_scale=7.5)

    decoded_image = model.lda.latents_to_image(latents)
    ensure_similar_images(decoded_image, expected_image_std_sde_random_init)
@no_grad() @no_grad()
def test_diffusion_batch2(sd15_std: StableDiffusion_1): def test_diffusion_batch2(sd15_std: StableDiffusion_1):
sd15 = sd15_std sd15 = sd15_std

View file

@ -67,6 +67,35 @@ Special cases:
- `kitchen_dog.png` is generated with the same Diffusers script and negative prompt, seed 12, positive prompt "a small brown dog, detailed high-quality professional image, sitting on a chair, in a kitchen". - `kitchen_dog.png` is generated with the same Diffusers script and negative prompt, seed 12, positive prompt "a small brown dog, detailed high-quality professional image, sitting on a chair, in a kitchen".
- `expected_std_sde_random_init.png` is generated with the following code:
```python
# Reference script: renders `expected_std_sde_random_init.png` with the
# Diffusers Stable Diffusion pipeline and its multistep DPM-Solver scheduler.
import torch
from diffusers import StableDiffusionPipeline
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from refiners.fluxion.utils import manual_seed
# Scheduler configuration: the "sde-dpmsolver++" algorithm selects the
# stochastic variant; the beta values and schedule are set explicitly.
diffusers_solver = DPMSolverMultistepScheduler.from_config( # type: ignore
{
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"algorithm_type": "sde-dpmsolver++",
"use_karras_sigmas": False,
"final_sigmas_type": "sigma_min",
"euler_at_final": True,
}
)
model_id = "runwayml/stable-diffusion-v1-5"
# Load the pipeline in float32 with the scheduler above, then move it to the GPU.
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32, scheduler=diffusers_solver)
pipe = pipe.to("cuda")
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
# Seed right before the pipeline call; no explicit generator and no step count
# are passed to the pipeline, so it runs with its own defaults for both.
manual_seed(2)
image = pipe(prompt, negative_prompt=negative_prompt, guidance_scale=7.5).images[0]
```
- `kitchen_mask.png` is made manually. - `kitchen_mask.png` is made manually.
- Controlnet guides have been manually generated (x) using open source software and models, namely: - Controlnet guides have been manually generated (x) using open source software and models, namely:

Binary file not shown.

After

Width:  |  Height:  |  Size: 349 KiB

View file

@ -59,6 +59,45 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}" assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
@pytest.mark.parametrize("n_steps, last_step_first_order", [(5, False), (5, True), (30, False), (30, True)])
def test_dpm_solver_sde_diffusers(n_steps: int, last_step_first_order: bool):
    """Check that the SDE variant of `DPMSolver` matches Diffusers' "sde-dpmsolver++" scheduler.

    Both implementations are stepped with the same sample and predicted noise, each
    after re-seeding with the same value, and their outputs are compared step by step.
    """
    from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler  # type: ignore

    manual_seed(0)

    scheduler_config = {
        "beta_schedule": "scaled_linear",
        "beta_start": 0.00085,
        "beta_end": 0.012,
        "lower_order_final": False,
        "euler_at_final": last_step_first_order,
        "final_sigmas_type": "sigma_min",  # default before Diffusers 0.26.0
        "algorithm_type": "sde-dpmsolver++",
    }
    diffusers_scheduler = DiffuserScheduler(**scheduler_config)
    diffusers_scheduler.set_timesteps(n_steps)

    solver = DPMSolver(
        num_inference_steps=n_steps,
        last_step_first_order=last_step_first_order,
        params=SolverParams(sde_variance=1.0),
    )
    assert equal(solver.timesteps, diffusers_scheduler.timesteps)

    sample = randn(1, 3, 32, 32)
    predicted_noise = randn(1, 3, 32, 32)

    # Re-seed before each run so both implementations draw the same noise sequence.
    manual_seed(37)
    diffusers_outputs: list[Tensor] = []
    for timestep in diffusers_scheduler.timesteps:
        step_result = diffusers_scheduler.step(predicted_noise, timestep, sample)  # type: ignore
        diffusers_outputs.append(cast(Tensor, step_result.prev_sample))  # type: ignore

    manual_seed(37)
    refiners_outputs: list[Tensor] = []
    for step_index in range(n_steps):
        refiners_outputs.append(solver(x=sample, predicted_noise=predicted_noise, step=step_index))

    for step, (diffusers_output, refiners_output) in enumerate(zip(diffusers_outputs, refiners_outputs)):
        assert allclose(diffusers_output, refiners_output, rtol=0.01, atol=1e-6), f"outputs differ at step {step}"
def test_ddim_diffusers(): def test_ddim_diffusers():
from diffusers import DDIMScheduler # type: ignore from diffusers import DDIMScheduler # type: ignore