mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 15:02:01 +00:00
Add stochastic sampling to DPM solver (SDE)
This commit is contained in:
parent
daee77298d
commit
09a9dfd494
|
@ -41,6 +41,8 @@ class DDIM(Solver):
|
|||
"""
|
||||
if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
|
||||
raise NotImplementedError
|
||||
if params and params.sde_variance != 0.0:
|
||||
raise NotImplementedError("DDIM does not support sde_variance != 0.0 yet")
|
||||
|
||||
super().__init__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
|
|
|
@ -2,7 +2,8 @@ import dataclasses
|
|||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor
|
||||
import torch
|
||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
|
||||
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||
BaseSolverParams,
|
||||
|
@ -51,6 +52,8 @@ class DPMSolver(Solver):
|
|||
"""
|
||||
if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
|
||||
raise NotImplementedError
|
||||
if params and params.sde_variance not in (0.0, 1.0):
|
||||
raise NotImplementedError("DPMSolver only supports sde_variance=0.0 or 1.0")
|
||||
|
||||
super().__init__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
|
@ -93,7 +96,9 @@ class DPMSolver(Solver):
|
|||
np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
|
||||
return tensor(np_space).flip(0)
|
||||
|
||||
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||
def dpm_solver_first_order_update(
|
||||
self, x: Tensor, noise: Tensor, step: int, sde_noise: Tensor | None = None
|
||||
) -> Tensor:
|
||||
"""Applies a first-order backward Euler update to the input data `x`.
|
||||
|
||||
Args:
|
||||
|
@ -115,11 +120,21 @@ class DPMSolver(Solver):
|
|||
previous_noise_std = self.noise_std[previous_timestep]
|
||||
current_noise_std = self.noise_std[current_timestep]
|
||||
|
||||
factor = exp(-(previous_ratio - current_ratio)) - 1.0
|
||||
denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise
|
||||
return denoised_x
|
||||
ratio_delta = current_ratio - previous_ratio
|
||||
|
||||
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
|
||||
if sde_noise is None:
|
||||
return (previous_noise_std / current_noise_std) * x + (
|
||||
1.0 - torch.exp(ratio_delta)
|
||||
) * previous_scale_factor * noise
|
||||
|
||||
factor = 1.0 - torch.exp(2.0 * ratio_delta)
|
||||
return (
|
||||
(previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
|
||||
+ previous_scale_factor * factor * noise
|
||||
+ previous_noise_std * torch.sqrt(factor) * sde_noise
|
||||
)
|
||||
|
||||
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int, sde_noise: Tensor | None = None) -> Tensor:
|
||||
"""Applies a second-order backward Euler update to the input data `x`.
|
||||
|
||||
Args:
|
||||
|
@ -147,13 +162,23 @@ class DPMSolver(Solver):
|
|||
estimation_delta = (current_data_estimation - next_data_estimation) / (
|
||||
(current_ratio - next_ratio) / (previous_ratio - current_ratio)
|
||||
)
|
||||
factor = exp(-(previous_ratio - current_ratio)) - 1.0
|
||||
denoised_x = (
|
||||
(previous_noise_std / current_noise_std) * x
|
||||
- (factor * previous_scale_factor) * current_data_estimation
|
||||
- 0.5 * (factor * previous_scale_factor) * estimation_delta
|
||||
ratio_delta = current_ratio - previous_ratio
|
||||
|
||||
if sde_noise is None:
|
||||
factor = 1.0 - torch.exp(ratio_delta)
|
||||
return (
|
||||
(previous_noise_std / current_noise_std) * x
|
||||
+ previous_scale_factor * factor * current_data_estimation
|
||||
+ 0.5 * previous_scale_factor * factor * estimation_delta
|
||||
)
|
||||
|
||||
factor = 1.0 - torch.exp(2.0 * ratio_delta)
|
||||
return (
|
||||
(previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
|
||||
+ previous_scale_factor * factor * current_data_estimation
|
||||
+ 0.5 * previous_scale_factor * factor * estimation_delta
|
||||
+ previous_noise_std * torch.sqrt(factor) * sde_noise
|
||||
)
|
||||
return denoised_x
|
||||
|
||||
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||
"""Apply one step of the backward diffusion process.
|
||||
|
@ -175,11 +200,20 @@ class DPMSolver(Solver):
|
|||
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
||||
|
||||
current_timestep = self.timesteps[step]
|
||||
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
|
||||
scale_factor = self.cumulative_scale_factors[current_timestep]
|
||||
noise_ratio = self.noise_std[current_timestep]
|
||||
estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor
|
||||
self.estimated_data.append(estimated_denoised_data)
|
||||
variance = self.params.sde_variance
|
||||
sde_noise = (
|
||||
torch.randn(x.shape, generator=generator, device=x.device, dtype=x.dtype) * variance
|
||||
if variance > 0.0
|
||||
else None
|
||||
)
|
||||
|
||||
if step == self.first_inference_step or (self.last_step_first_order and step == self.num_inference_steps - 1):
|
||||
return self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
|
||||
return self.dpm_solver_first_order_update(
|
||||
x=x, noise=estimated_denoised_data, step=step, sde_noise=sde_noise
|
||||
)
|
||||
|
||||
return self.multistep_dpm_solver_second_order_update(x=x, step=step)
|
||||
return self.multistep_dpm_solver_second_order_update(x=x, step=step, sde_noise=sde_noise)
|
||||
|
|
|
@ -36,6 +36,8 @@ class Euler(Solver):
|
|||
"""
|
||||
if params and params.noise_schedule not in (NoiseSchedule.QUADRATIC, None):
|
||||
raise NotImplementedError
|
||||
if params and params.sde_variance != 0.0:
|
||||
raise NotImplementedError("Euler does not support sde_variance != 0.0 yet")
|
||||
|
||||
super().__init__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
|
|
|
@ -79,6 +79,7 @@ class BaseSolverParams:
|
|||
final_diffusion_rate: float | None
|
||||
noise_schedule: NoiseSchedule | None
|
||||
model_prediction_type: ModelPredictionType | None
|
||||
sde_variance: float
|
||||
|
||||
|
||||
@dataclasses.dataclass(kw_only=True, frozen=True)
|
||||
|
@ -102,6 +103,7 @@ class SolverParams(BaseSolverParams):
|
|||
final_diffusion_rate: float | None = None
|
||||
noise_schedule: NoiseSchedule | None = None
|
||||
model_prediction_type: ModelPredictionType | None = None
|
||||
sde_variance: float = 0.0
|
||||
|
||||
|
||||
@dataclasses.dataclass(kw_only=True, frozen=True)
|
||||
|
@ -113,6 +115,7 @@ class ResolvedSolverParams(BaseSolverParams):
|
|||
final_diffusion_rate: float
|
||||
noise_schedule: NoiseSchedule
|
||||
model_prediction_type: ModelPredictionType
|
||||
sde_variance: float
|
||||
|
||||
|
||||
class Solver(fl.Module, ABC):
|
||||
|
@ -123,6 +126,19 @@ class Solver(fl.Module, ABC):
|
|||
|
||||
This process is described using several parameters such as initial and final diffusion rates,
|
||||
and is encapsulated into a `__call__` method that applies a step of the diffusion process.
|
||||
|
||||
Attributes:
|
||||
params: The common parameters for solvers. See `SolverParams`.
|
||||
num_inference_steps: The number of inference steps to perform.
|
||||
first_inference_step: The step to start the inference process from.
|
||||
scale_factors: The scale factors used to denoise the input. These are called "betas" in other implementations,
|
||||
and `1 - scale_factors` is called "alphas".
|
||||
cumulative_scale_factors: The cumulative scale factors used to denoise the input. These are called "alpha_t" in
|
||||
other implementations.
|
||||
noise_std: The standard deviation of the noise used to denoise the input. This is called "sigma_t" in other
|
||||
implementations.
|
||||
signal_to_noise_ratios: The signal-to-noise ratios used to denoise the input. This is called "lambda_t" in other
|
||||
implementations.
|
||||
"""
|
||||
|
||||
timesteps: Tensor
|
||||
|
@ -136,6 +152,7 @@ class Solver(fl.Module, ABC):
|
|||
final_diffusion_rate=1.2e-2,
|
||||
noise_schedule=NoiseSchedule.QUADRATIC,
|
||||
model_prediction_type=ModelPredictionType.NOISE,
|
||||
sde_variance=0.0,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
|
|
|
@ -88,6 +88,11 @@ def expected_image_std_random_init(ref_path: Path) -> Image.Image:
|
|||
return _img_open(ref_path / "expected_std_random_init.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_std_sde_random_init(ref_path: Path) -> Image.Image:
|
||||
return _img_open(ref_path / "expected_std_sde_random_init.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
|
||||
return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB")
|
||||
|
@ -560,6 +565,24 @@ def sd15_std(
|
|||
return sd15
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sd15_std_sde(
|
||||
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
|
||||
) -> StableDiffusion_1:
|
||||
if test_device.type == "cpu":
|
||||
warn("not running on CPU, skipping")
|
||||
pytest.skip()
|
||||
|
||||
sde_solver = DPMSolver(num_inference_steps=30, last_step_first_order=True, params=SolverParams(sde_variance=1.0))
|
||||
sd15 = StableDiffusion_1(device=test_device, solver=sde_solver)
|
||||
|
||||
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
||||
sd15.lda.load_from_safetensors(lda_weights)
|
||||
sd15.unet.load_from_safetensors(unet_weights_std)
|
||||
|
||||
return sd15
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sd15_std_float16(
|
||||
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
|
||||
|
@ -831,6 +854,33 @@ def test_diffusion_std_random_init(
|
|||
ensure_similar_images(predicted_image, expected_image_std_random_init)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_diffusion_std_sde_random_init(
|
||||
sd15_std_sde: StableDiffusion_1, expected_image_std_sde_random_init: Image.Image, test_device: torch.device
|
||||
):
|
||||
sd15 = sd15_std_sde
|
||||
|
||||
prompt = "a cute cat, detailed high-quality professional image"
|
||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
|
||||
sd15.set_inference_steps(50)
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||
|
||||
for step in sd15.steps:
|
||||
x = sd15(
|
||||
x,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_std_sde_random_init)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_diffusion_batch2(sd15_std: StableDiffusion_1):
|
||||
sd15 = sd15_std
|
||||
|
|
|
@ -67,6 +67,35 @@ Special cases:
|
|||
|
||||
- `kitchen_dog.png` is generated with the same Diffusers script and negative prompt, seed 12, positive prompt "a small brown dog, detailed high-quality professional image, sitting on a chair, in a kitchen".
|
||||
|
||||
- `expected_std_sde_random_init.png` is generated with the following code:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
|
||||
|
||||
from refiners.fluxion.utils import manual_seed
|
||||
|
||||
diffusers_solver = DPMSolverMultistepScheduler.from_config( # type: ignore
|
||||
{
|
||||
"beta_end": 0.012,
|
||||
"beta_schedule": "scaled_linear",
|
||||
"beta_start": 0.00085,
|
||||
"algorithm_type": "sde-dpmsolver++",
|
||||
"use_karras_sigmas": False,
|
||||
"final_sigmas_type": "sigma_min",
|
||||
"euler_at_final": True,
|
||||
}
|
||||
)
|
||||
model_id = "runwayml/stable-diffusion-v1-5"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32, scheduler=diffusers_solver)
|
||||
pipe = pipe.to("cuda")
|
||||
prompt = "a cute cat, detailed high-quality professional image"
|
||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
manual_seed(2)
|
||||
image = pipe(prompt, negative_prompt=negative_prompt, guidance_scale=7.5).images[0]
|
||||
```
|
||||
|
||||
- `kitchen_mask.png` is made manually.
|
||||
|
||||
- Controlnet guides have been manually generated (x) using open source software and models, namely:
|
||||
|
|
BIN
tests/e2e/test_diffusion_ref/expected_std_sde_random_init.png
Normal file
BIN
tests/e2e/test_diffusion_ref/expected_std_sde_random_init.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 349 KiB |
|
@ -59,6 +59,45 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
|
|||
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_steps, last_step_first_order", [(5, False), (5, True), (30, False), (30, True)])
|
||||
def test_dpm_solver_sde_diffusers(n_steps: int, last_step_first_order: bool):
|
||||
from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler # type: ignore
|
||||
|
||||
manual_seed(0)
|
||||
|
||||
diffusers_scheduler = DiffuserScheduler(
|
||||
beta_schedule="scaled_linear",
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
lower_order_final=False,
|
||||
euler_at_final=last_step_first_order,
|
||||
final_sigmas_type="sigma_min", # default before Diffusers 0.26.0
|
||||
algorithm_type="sde-dpmsolver++",
|
||||
)
|
||||
diffusers_scheduler.set_timesteps(n_steps)
|
||||
solver = DPMSolver(
|
||||
num_inference_steps=n_steps,
|
||||
last_step_first_order=last_step_first_order,
|
||||
params=SolverParams(sde_variance=1.0),
|
||||
)
|
||||
assert equal(solver.timesteps, diffusers_scheduler.timesteps)
|
||||
|
||||
sample = randn(1, 3, 32, 32)
|
||||
predicted_noise = randn(1, 3, 32, 32)
|
||||
|
||||
manual_seed(37)
|
||||
diffusers_outputs: list[Tensor] = [
|
||||
cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
|
||||
for timestep in diffusers_scheduler.timesteps
|
||||
]
|
||||
|
||||
manual_seed(37)
|
||||
refiners_outputs = [solver(x=sample, predicted_noise=predicted_noise, step=step) for step in range(n_steps)]
|
||||
|
||||
for step, (diffusers_output, refiners_output) in enumerate(zip(diffusers_outputs, refiners_outputs)):
|
||||
assert allclose(diffusers_output, refiners_output, rtol=0.01, atol=1e-6), f"outputs differ at step {step}"
|
||||
|
||||
|
||||
def test_ddim_diffusers():
|
||||
from diffusers import DDIMScheduler # type: ignore
|
||||
|
||||
|
|
Loading…
Reference in a new issue