mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
Add stochastic sampling to DPM solver (SDE)
This commit is contained in:
parent
daee77298d
commit
09a9dfd494
|
@ -41,6 +41,8 @@ class DDIM(Solver):
|
||||||
"""
|
"""
|
||||||
if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
|
if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
if params and params.sde_variance != 0.0:
|
||||||
|
raise NotImplementedError("DDIM does not support sde_variance != 0.0 yet")
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
|
|
|
@ -2,7 +2,8 @@ import dataclasses
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor
|
import torch
|
||||||
|
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||||
BaseSolverParams,
|
BaseSolverParams,
|
||||||
|
@ -51,6 +52,8 @@ class DPMSolver(Solver):
|
||||||
"""
|
"""
|
||||||
if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
|
if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
if params and params.sde_variance not in (0.0, 1.0):
|
||||||
|
raise NotImplementedError("DPMSolver only supports sde_variance=0.0 or 1.0")
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
|
@ -93,7 +96,9 @@ class DPMSolver(Solver):
|
||||||
np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
|
np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
|
||||||
return tensor(np_space).flip(0)
|
return tensor(np_space).flip(0)
|
||||||
|
|
||||||
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
def dpm_solver_first_order_update(
|
||||||
|
self, x: Tensor, noise: Tensor, step: int, sde_noise: Tensor | None = None
|
||||||
|
) -> Tensor:
|
||||||
"""Applies a first-order backward Euler update to the input data `x`.
|
"""Applies a first-order backward Euler update to the input data `x`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -115,11 +120,21 @@ class DPMSolver(Solver):
|
||||||
previous_noise_std = self.noise_std[previous_timestep]
|
previous_noise_std = self.noise_std[previous_timestep]
|
||||||
current_noise_std = self.noise_std[current_timestep]
|
current_noise_std = self.noise_std[current_timestep]
|
||||||
|
|
||||||
factor = exp(-(previous_ratio - current_ratio)) - 1.0
|
ratio_delta = current_ratio - previous_ratio
|
||||||
denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise
|
|
||||||
return denoised_x
|
|
||||||
|
|
||||||
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
|
if sde_noise is None:
|
||||||
|
return (previous_noise_std / current_noise_std) * x + (
|
||||||
|
1.0 - torch.exp(ratio_delta)
|
||||||
|
) * previous_scale_factor * noise
|
||||||
|
|
||||||
|
factor = 1.0 - torch.exp(2.0 * ratio_delta)
|
||||||
|
return (
|
||||||
|
(previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
|
||||||
|
+ previous_scale_factor * factor * noise
|
||||||
|
+ previous_noise_std * torch.sqrt(factor) * sde_noise
|
||||||
|
)
|
||||||
|
|
||||||
|
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int, sde_noise: Tensor | None = None) -> Tensor:
|
||||||
"""Applies a second-order backward Euler update to the input data `x`.
|
"""Applies a second-order backward Euler update to the input data `x`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -147,13 +162,23 @@ class DPMSolver(Solver):
|
||||||
estimation_delta = (current_data_estimation - next_data_estimation) / (
|
estimation_delta = (current_data_estimation - next_data_estimation) / (
|
||||||
(current_ratio - next_ratio) / (previous_ratio - current_ratio)
|
(current_ratio - next_ratio) / (previous_ratio - current_ratio)
|
||||||
)
|
)
|
||||||
factor = exp(-(previous_ratio - current_ratio)) - 1.0
|
ratio_delta = current_ratio - previous_ratio
|
||||||
denoised_x = (
|
|
||||||
|
if sde_noise is None:
|
||||||
|
factor = 1.0 - torch.exp(ratio_delta)
|
||||||
|
return (
|
||||||
(previous_noise_std / current_noise_std) * x
|
(previous_noise_std / current_noise_std) * x
|
||||||
- (factor * previous_scale_factor) * current_data_estimation
|
+ previous_scale_factor * factor * current_data_estimation
|
||||||
- 0.5 * (factor * previous_scale_factor) * estimation_delta
|
+ 0.5 * previous_scale_factor * factor * estimation_delta
|
||||||
|
)
|
||||||
|
|
||||||
|
factor = 1.0 - torch.exp(2.0 * ratio_delta)
|
||||||
|
return (
|
||||||
|
(previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
|
||||||
|
+ previous_scale_factor * factor * current_data_estimation
|
||||||
|
+ 0.5 * previous_scale_factor * factor * estimation_delta
|
||||||
|
+ previous_noise_std * torch.sqrt(factor) * sde_noise
|
||||||
)
|
)
|
||||||
return denoised_x
|
|
||||||
|
|
||||||
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||||
"""Apply one step of the backward diffusion process.
|
"""Apply one step of the backward diffusion process.
|
||||||
|
@ -175,11 +200,20 @@ class DPMSolver(Solver):
|
||||||
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
||||||
|
|
||||||
current_timestep = self.timesteps[step]
|
current_timestep = self.timesteps[step]
|
||||||
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
|
scale_factor = self.cumulative_scale_factors[current_timestep]
|
||||||
|
noise_ratio = self.noise_std[current_timestep]
|
||||||
estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor
|
estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor
|
||||||
self.estimated_data.append(estimated_denoised_data)
|
self.estimated_data.append(estimated_denoised_data)
|
||||||
|
variance = self.params.sde_variance
|
||||||
|
sde_noise = (
|
||||||
|
torch.randn(x.shape, generator=generator, device=x.device, dtype=x.dtype) * variance
|
||||||
|
if variance > 0.0
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
if step == self.first_inference_step or (self.last_step_first_order and step == self.num_inference_steps - 1):
|
if step == self.first_inference_step or (self.last_step_first_order and step == self.num_inference_steps - 1):
|
||||||
return self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
|
return self.dpm_solver_first_order_update(
|
||||||
|
x=x, noise=estimated_denoised_data, step=step, sde_noise=sde_noise
|
||||||
|
)
|
||||||
|
|
||||||
return self.multistep_dpm_solver_second_order_update(x=x, step=step)
|
return self.multistep_dpm_solver_second_order_update(x=x, step=step, sde_noise=sde_noise)
|
||||||
|
|
|
@ -36,6 +36,8 @@ class Euler(Solver):
|
||||||
"""
|
"""
|
||||||
if params and params.noise_schedule not in (NoiseSchedule.QUADRATIC, None):
|
if params and params.noise_schedule not in (NoiseSchedule.QUADRATIC, None):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
if params and params.sde_variance != 0.0:
|
||||||
|
raise NotImplementedError("Euler does not support sde_variance != 0.0 yet")
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
|
|
|
@ -79,6 +79,7 @@ class BaseSolverParams:
|
||||||
final_diffusion_rate: float | None
|
final_diffusion_rate: float | None
|
||||||
noise_schedule: NoiseSchedule | None
|
noise_schedule: NoiseSchedule | None
|
||||||
model_prediction_type: ModelPredictionType | None
|
model_prediction_type: ModelPredictionType | None
|
||||||
|
sde_variance: float
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(kw_only=True, frozen=True)
|
@dataclasses.dataclass(kw_only=True, frozen=True)
|
||||||
|
@ -102,6 +103,7 @@ class SolverParams(BaseSolverParams):
|
||||||
final_diffusion_rate: float | None = None
|
final_diffusion_rate: float | None = None
|
||||||
noise_schedule: NoiseSchedule | None = None
|
noise_schedule: NoiseSchedule | None = None
|
||||||
model_prediction_type: ModelPredictionType | None = None
|
model_prediction_type: ModelPredictionType | None = None
|
||||||
|
sde_variance: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(kw_only=True, frozen=True)
|
@dataclasses.dataclass(kw_only=True, frozen=True)
|
||||||
|
@ -113,6 +115,7 @@ class ResolvedSolverParams(BaseSolverParams):
|
||||||
final_diffusion_rate: float
|
final_diffusion_rate: float
|
||||||
noise_schedule: NoiseSchedule
|
noise_schedule: NoiseSchedule
|
||||||
model_prediction_type: ModelPredictionType
|
model_prediction_type: ModelPredictionType
|
||||||
|
sde_variance: float
|
||||||
|
|
||||||
|
|
||||||
class Solver(fl.Module, ABC):
|
class Solver(fl.Module, ABC):
|
||||||
|
@ -123,6 +126,19 @@ class Solver(fl.Module, ABC):
|
||||||
|
|
||||||
This process is described using several parameters such as initial and final diffusion rates,
|
This process is described using several parameters such as initial and final diffusion rates,
|
||||||
and is encapsulated into a `__call__` method that applies a step of the diffusion process.
|
and is encapsulated into a `__call__` method that applies a step of the diffusion process.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
params: The common parameters for solvers. See `SolverParams`.
|
||||||
|
num_inference_steps: The number of inference steps to perform.
|
||||||
|
first_inference_step: The step to start the inference process from.
|
||||||
|
scale_factors: The scale factors used to denoise the input. These are called "betas" in other implementations,
|
||||||
|
and `1 - scale_factors` is called "alphas".
|
||||||
|
cumulative_scale_factors: The cumulative scale factors used to denoise the input. These are called "alpha_t" in
|
||||||
|
other implementations.
|
||||||
|
noise_std: The standard deviation of the noise used to denoise the input. This is called "sigma_t" in other
|
||||||
|
implementations.
|
||||||
|
signal_to_noise_ratios: The signal-to-noise ratios used to denoise the input. This is called "lambda_t" in other
|
||||||
|
implementations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
timesteps: Tensor
|
timesteps: Tensor
|
||||||
|
@ -136,6 +152,7 @@ class Solver(fl.Module, ABC):
|
||||||
final_diffusion_rate=1.2e-2,
|
final_diffusion_rate=1.2e-2,
|
||||||
noise_schedule=NoiseSchedule.QUADRATIC,
|
noise_schedule=NoiseSchedule.QUADRATIC,
|
||||||
model_prediction_type=ModelPredictionType.NOISE,
|
model_prediction_type=ModelPredictionType.NOISE,
|
||||||
|
sde_variance=0.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -88,6 +88,11 @@ def expected_image_std_random_init(ref_path: Path) -> Image.Image:
|
||||||
return _img_open(ref_path / "expected_std_random_init.png").convert("RGB")
|
return _img_open(ref_path / "expected_std_random_init.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def expected_image_std_sde_random_init(ref_path: Path) -> Image.Image:
|
||||||
|
return _img_open(ref_path / "expected_std_sde_random_init.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
|
def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
|
||||||
return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB")
|
return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB")
|
||||||
|
@ -560,6 +565,24 @@ def sd15_std(
|
||||||
return sd15
|
return sd15
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sd15_std_sde(
|
||||||
|
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
|
||||||
|
) -> StableDiffusion_1:
|
||||||
|
if test_device.type == "cpu":
|
||||||
|
warn("not running on CPU, skipping")
|
||||||
|
pytest.skip()
|
||||||
|
|
||||||
|
sde_solver = DPMSolver(num_inference_steps=30, last_step_first_order=True, params=SolverParams(sde_variance=1.0))
|
||||||
|
sd15 = StableDiffusion_1(device=test_device, solver=sde_solver)
|
||||||
|
|
||||||
|
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
||||||
|
sd15.lda.load_from_safetensors(lda_weights)
|
||||||
|
sd15.unet.load_from_safetensors(unet_weights_std)
|
||||||
|
|
||||||
|
return sd15
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sd15_std_float16(
|
def sd15_std_float16(
|
||||||
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
|
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
|
||||||
|
@ -831,6 +854,33 @@ def test_diffusion_std_random_init(
|
||||||
ensure_similar_images(predicted_image, expected_image_std_random_init)
|
ensure_similar_images(predicted_image, expected_image_std_random_init)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_diffusion_std_sde_random_init(
|
||||||
|
sd15_std_sde: StableDiffusion_1, expected_image_std_sde_random_init: Image.Image, test_device: torch.device
|
||||||
|
):
|
||||||
|
sd15 = sd15_std_sde
|
||||||
|
|
||||||
|
prompt = "a cute cat, detailed high-quality professional image"
|
||||||
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||||
|
|
||||||
|
sd15.set_inference_steps(50)
|
||||||
|
|
||||||
|
manual_seed(2)
|
||||||
|
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||||
|
|
||||||
|
for step in sd15.steps:
|
||||||
|
x = sd15(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
condition_scale=7.5,
|
||||||
|
)
|
||||||
|
predicted_image = sd15.lda.latents_to_image(x)
|
||||||
|
|
||||||
|
ensure_similar_images(predicted_image, expected_image_std_sde_random_init)
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_diffusion_batch2(sd15_std: StableDiffusion_1):
|
def test_diffusion_batch2(sd15_std: StableDiffusion_1):
|
||||||
sd15 = sd15_std
|
sd15 = sd15_std
|
||||||
|
|
|
@ -67,6 +67,35 @@ Special cases:
|
||||||
|
|
||||||
- `kitchen_dog.png` is generated with the same Diffusers script and negative prompt, seed 12, positive prompt "a small brown dog, detailed high-quality professional image, sitting on a chair, in a kitchen".
|
- `kitchen_dog.png` is generated with the same Diffusers script and negative prompt, seed 12, positive prompt "a small brown dog, detailed high-quality professional image, sitting on a chair, in a kitchen".
|
||||||
|
|
||||||
|
- `expected_std_sde_random_init.png` is generated with the following code:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffusers import StableDiffusionPipeline
|
||||||
|
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import manual_seed
|
||||||
|
|
||||||
|
diffusers_solver = DPMSolverMultistepScheduler.from_config( # type: ignore
|
||||||
|
{
|
||||||
|
"beta_end": 0.012,
|
||||||
|
"beta_schedule": "scaled_linear",
|
||||||
|
"beta_start": 0.00085,
|
||||||
|
"algorithm_type": "sde-dpmsolver++",
|
||||||
|
"use_karras_sigmas": False,
|
||||||
|
"final_sigmas_type": "sigma_min",
|
||||||
|
"euler_at_final": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
model_id = "runwayml/stable-diffusion-v1-5"
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32, scheduler=diffusers_solver)
|
||||||
|
pipe = pipe.to("cuda")
|
||||||
|
prompt = "a cute cat, detailed high-quality professional image"
|
||||||
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
manual_seed(2)
|
||||||
|
image = pipe(prompt, negative_prompt=negative_prompt, guidance_scale=7.5).images[0]
|
||||||
|
```
|
||||||
|
|
||||||
- `kitchen_mask.png` is made manually.
|
- `kitchen_mask.png` is made manually.
|
||||||
|
|
||||||
- Controlnet guides have been manually generated (x) using open source software and models, namely:
|
- Controlnet guides have been manually generated (x) using open source software and models, namely:
|
||||||
|
|
BIN
tests/e2e/test_diffusion_ref/expected_std_sde_random_init.png
Normal file
BIN
tests/e2e/test_diffusion_ref/expected_std_sde_random_init.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 349 KiB |
|
@ -59,6 +59,45 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
|
||||||
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("n_steps, last_step_first_order", [(5, False), (5, True), (30, False), (30, True)])
|
||||||
|
def test_dpm_solver_sde_diffusers(n_steps: int, last_step_first_order: bool):
|
||||||
|
from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler # type: ignore
|
||||||
|
|
||||||
|
manual_seed(0)
|
||||||
|
|
||||||
|
diffusers_scheduler = DiffuserScheduler(
|
||||||
|
beta_schedule="scaled_linear",
|
||||||
|
beta_start=0.00085,
|
||||||
|
beta_end=0.012,
|
||||||
|
lower_order_final=False,
|
||||||
|
euler_at_final=last_step_first_order,
|
||||||
|
final_sigmas_type="sigma_min", # default before Diffusers 0.26.0
|
||||||
|
algorithm_type="sde-dpmsolver++",
|
||||||
|
)
|
||||||
|
diffusers_scheduler.set_timesteps(n_steps)
|
||||||
|
solver = DPMSolver(
|
||||||
|
num_inference_steps=n_steps,
|
||||||
|
last_step_first_order=last_step_first_order,
|
||||||
|
params=SolverParams(sde_variance=1.0),
|
||||||
|
)
|
||||||
|
assert equal(solver.timesteps, diffusers_scheduler.timesteps)
|
||||||
|
|
||||||
|
sample = randn(1, 3, 32, 32)
|
||||||
|
predicted_noise = randn(1, 3, 32, 32)
|
||||||
|
|
||||||
|
manual_seed(37)
|
||||||
|
diffusers_outputs: list[Tensor] = [
|
||||||
|
cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
|
||||||
|
for timestep in diffusers_scheduler.timesteps
|
||||||
|
]
|
||||||
|
|
||||||
|
manual_seed(37)
|
||||||
|
refiners_outputs = [solver(x=sample, predicted_noise=predicted_noise, step=step) for step in range(n_steps)]
|
||||||
|
|
||||||
|
for step, (diffusers_output, refiners_output) in enumerate(zip(diffusers_outputs, refiners_outputs)):
|
||||||
|
assert allclose(diffusers_output, refiners_output, rtol=0.01, atol=1e-6), f"outputs differ at step {step}"
|
||||||
|
|
||||||
|
|
||||||
def test_ddim_diffusers():
|
def test_ddim_diffusers():
|
||||||
from diffusers import DDIMScheduler # type: ignore
|
from diffusers import DDIMScheduler # type: ignore
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue