diff --git a/src/refiners/foundationals/latent_diffusion/solvers/ddim.py b/src/refiners/foundationals/latent_diffusion/solvers/ddim.py
index 6b21980..31e64b9 100644
--- a/src/refiners/foundationals/latent_diffusion/solvers/ddim.py
+++ b/src/refiners/foundationals/latent_diffusion/solvers/ddim.py
@@ -41,6 +41,8 @@ class DDIM(Solver):
         """
         if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
             raise NotImplementedError
+        if params and params.sde_variance != 0.0:
+            raise NotImplementedError("DDIM does not support sde_variance != 0.0 yet")
 
         super().__init__(
             num_inference_steps=num_inference_steps,
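Note: the guard added to `DDIM` (and mirrored in `Euler` below) makes a non-zero `sde_variance` fail fast at construction time instead of being silently ignored. A minimal sketch of the expected behaviour, assuming `DDIM` and `SolverParams` are importable from the `solvers` package as in the tests further down this diff:

```python
# Sketch: solvers without SDE support reject a non-zero sde_variance up front.
import pytest

from refiners.foundationals.latent_diffusion.solvers import DDIM, SolverParams

with pytest.raises(NotImplementedError):
    DDIM(num_inference_steps=30, params=SolverParams(sde_variance=1.0))
```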
diff --git a/src/refiners/foundationals/latent_diffusion/solvers/dpm.py b/src/refiners/foundationals/latent_diffusion/solvers/dpm.py
index 2074715..c7b808b 100644
--- a/src/refiners/foundationals/latent_diffusion/solvers/dpm.py
+++ b/src/refiners/foundationals/latent_diffusion/solvers/dpm.py
@@ -2,7 +2,8 @@ import dataclasses
 from collections import deque
 
 import numpy as np
-from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor
+import torch
+from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
 
 from refiners.foundationals.latent_diffusion.solvers.solver import (
     BaseSolverParams,
@@ -51,6 +52,8 @@ class DPMSolver(Solver):
         """
         if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
             raise NotImplementedError
+        if params and params.sde_variance not in (0.0, 1.0):
+            raise NotImplementedError("DPMSolver only supports sde_variance=0.0 or 1.0")
 
         super().__init__(
             num_inference_steps=num_inference_steps,
@@ -93,7 +96,9 @@ class DPMSolver(Solver):
         np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
         return tensor(np_space).flip(0)
 
-    def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
+    def dpm_solver_first_order_update(
+        self, x: Tensor, noise: Tensor, step: int, sde_noise: Tensor | None = None
+    ) -> Tensor:
         """Applies a first-order backward Euler update to the input data `x`.
 
         Args:
@@ -115,11 +120,21 @@ class DPMSolver(Solver):
         previous_noise_std = self.noise_std[previous_timestep]
         current_noise_std = self.noise_std[current_timestep]
 
-        factor = exp(-(previous_ratio - current_ratio)) - 1.0
-        denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise
-        return denoised_x
+        ratio_delta = current_ratio - previous_ratio
 
-    def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
+        if sde_noise is None:
+            return (previous_noise_std / current_noise_std) * x + (
+                1.0 - torch.exp(ratio_delta)
+            ) * previous_scale_factor * noise
+
+        factor = 1.0 - torch.exp(2.0 * ratio_delta)
+        return (
+            (previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
+            + previous_scale_factor * factor * noise
+            + previous_noise_std * torch.sqrt(factor) * sde_noise
+        )
+
+    def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int, sde_noise: Tensor | None = None) -> Tensor:
         """Applies a second-order backward Euler update to the input data `x`.
 
         Args:
@@ -147,13 +162,23 @@ class DPMSolver(Solver):
         estimation_delta = (current_data_estimation - next_data_estimation) / (
             (current_ratio - next_ratio) / (previous_ratio - current_ratio)
         )
-        factor = exp(-(previous_ratio - current_ratio)) - 1.0
-        denoised_x = (
-            (previous_noise_std / current_noise_std) * x
-            - (factor * previous_scale_factor) * current_data_estimation
-            - 0.5 * (factor * previous_scale_factor) * estimation_delta
+        ratio_delta = current_ratio - previous_ratio
+
+        if sde_noise is None:
+            factor = 1.0 - torch.exp(ratio_delta)
+            return (
+                (previous_noise_std / current_noise_std) * x
+                + previous_scale_factor * factor * current_data_estimation
+                + 0.5 * previous_scale_factor * factor * estimation_delta
+            )
+
+        factor = 1.0 - torch.exp(2.0 * ratio_delta)
+        return (
+            (previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
+            + previous_scale_factor * factor * current_data_estimation
+            + 0.5 * previous_scale_factor * factor * estimation_delta
+            + previous_noise_std * torch.sqrt(factor) * sde_noise
         )
-        return denoised_x
 
     def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
         """Apply one step of the backward diffusion process.
@@ -175,11 +200,20 @@ class DPMSolver(Solver):
         assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
 
         current_timestep = self.timesteps[step]
-        scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
+        scale_factor = self.cumulative_scale_factors[current_timestep]
+        noise_ratio = self.noise_std[current_timestep]
         estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor
         self.estimated_data.append(estimated_denoised_data)
+        variance = self.params.sde_variance
+        sde_noise = (
+            torch.randn(x.shape, generator=generator, device=x.device, dtype=x.dtype) * variance
+            if variance > 0.0
+            else None
+        )
 
         if step == self.first_inference_step or (self.last_step_first_order and step == self.num_inference_steps - 1):
-            return self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
+            return self.dpm_solver_first_order_update(
+                x=x, noise=estimated_denoised_data, step=step, sde_noise=sde_noise
+            )
 
-        return self.multistep_dpm_solver_second_order_update(x=x, step=step)
+        return self.multistep_dpm_solver_second_order_update(x=x, step=step, sde_noise=sde_noise)
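Note: in both updates, `ratio_delta = current_ratio - previous_ratio` is the negated signal-to-noise log-ratio step, so the `sde_noise is None` branch is only a rewrite of the removed closed form (`exp(-(previous_ratio - current_ratio)) - 1.0` is exactly `-(1.0 - torch.exp(ratio_delta))`), while the SDE branch additionally rescales `x` by `exp(ratio_delta)` and injects noise scaled by `sqrt(1 - exp(2 * ratio_delta))`, as in SDE-DPM-Solver++. A self-contained sanity check of the deterministic rewrite, with arbitrary stand-in scalars:

```python
# Sketch: the refactored deterministic (ODE) branch equals the removed formula.
# All schedule values below are arbitrary stand-ins, not real SD 1.5 values.
import torch

current_ratio, previous_ratio = torch.tensor(0.3), torch.tensor(0.9)
previous_scale_factor = torch.tensor(0.8)
previous_noise_std, current_noise_std = torch.tensor(0.4), torch.tensor(0.7)
x, noise = torch.tensor(1.5), torch.tensor(-0.2)

# Removed formulation.
factor = torch.exp(-(previous_ratio - current_ratio)) - 1.0
old = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise

# New formulation with sde_noise=None.
ratio_delta = current_ratio - previous_ratio
new = (previous_noise_std / current_noise_std) * x + (
    1.0 - torch.exp(ratio_delta)
) * previous_scale_factor * noise

assert torch.allclose(old, new)
```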
diff --git a/src/refiners/foundationals/latent_diffusion/solvers/euler.py b/src/refiners/foundationals/latent_diffusion/solvers/euler.py
index 09f8007..69e9da2 100644
--- a/src/refiners/foundationals/latent_diffusion/solvers/euler.py
+++ b/src/refiners/foundationals/latent_diffusion/solvers/euler.py
@@ -36,6 +36,8 @@ class Euler(Solver):
         """
         if params and params.noise_schedule not in (NoiseSchedule.QUADRATIC, None):
             raise NotImplementedError
+        if params and params.sde_variance != 0.0:
+            raise NotImplementedError("Euler does not support sde_variance != 0.0 yet")
 
         super().__init__(
             num_inference_steps=num_inference_steps,
diff --git a/src/refiners/foundationals/latent_diffusion/solvers/solver.py b/src/refiners/foundationals/latent_diffusion/solvers/solver.py
index 1572091..6ff2be9 100644
--- a/src/refiners/foundationals/latent_diffusion/solvers/solver.py
+++ b/src/refiners/foundationals/latent_diffusion/solvers/solver.py
@@ -79,6 +79,7 @@ class BaseSolverParams:
     final_diffusion_rate: float | None
     noise_schedule: NoiseSchedule | None
     model_prediction_type: ModelPredictionType | None
+    sde_variance: float
 
 
 @dataclasses.dataclass(kw_only=True, frozen=True)
@@ -102,6 +103,7 @@ class SolverParams(BaseSolverParams):
     final_diffusion_rate: float | None = None
     noise_schedule: NoiseSchedule | None = None
     model_prediction_type: ModelPredictionType | None = None
+    sde_variance: float = 0.0
 
 
 @dataclasses.dataclass(kw_only=True, frozen=True)
@@ -113,6 +115,7 @@ class ResolvedSolverParams(BaseSolverParams):
     final_diffusion_rate: float
     noise_schedule: NoiseSchedule
     model_prediction_type: ModelPredictionType
+    sde_variance: float
 
 
 class Solver(fl.Module, ABC):
@@ -123,6 +126,19 @@ class Solver(fl.Module, ABC):
 
     This process is described using several parameters such as initial and final diffusion rates,
     and is encapsulated into a `__call__` method that applies a step of the diffusion process.
+
+    Attributes:
+        params: The common parameters for solvers. See `SolverParams`.
+        num_inference_steps: The number of inference steps to perform.
+        first_inference_step: The step to start the inference process from.
+        scale_factors: The scale factors used to denoise the input. These are called "betas" in other implementations,
+            and `1 - scale_factors` is called "alphas".
+        cumulative_scale_factors: The cumulative scale factors used to denoise the input. These are called "alpha_t" in
+            other implementations.
+        noise_std: The standard deviation of the noise used to denoise the input. This is called "sigma_t" in other
+            implementations.
+        signal_to_noise_ratios: The signal-to-noise ratios used to denoise the input. This is called "lambda_t" in other
+            implementations.
     """
 
     timesteps: Tensor
@@ -136,6 +152,7 @@ class Solver(fl.Module, ABC):
         final_diffusion_rate=1.2e-2,
         noise_schedule=NoiseSchedule.QUADRATIC,
         model_prediction_type=ModelPredictionType.NOISE,
+        sde_variance=0.0,
     )
 
     def __init__(
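Note: `sde_variance` threads through all three parameter dataclasses with a default of `0.0`, so existing callers keep the deterministic behaviour. A usage sketch (import path assumed to match the tests in this diff):

```python
# Sketch: opting into the stochastic (SDE) variant of DPMSolver.
import torch

from refiners.foundationals.latent_diffusion.solvers import DPMSolver, SolverParams

solver = DPMSolver(num_inference_steps=30, params=SolverParams(sde_variance=1.0))

x = torch.randn(1, 4, 64, 64)
predicted_noise = torch.randn(1, 4, 64, 64)
generator = torch.Generator().manual_seed(37)  # drives the extra SDE noise draw

x = solver(x=x, predicted_noise=predicted_noise, step=0, generator=generator)
```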
diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py
index 1b1279e..d4e4865 100644
--- a/tests/e2e/test_diffusion.py
+++ b/tests/e2e/test_diffusion.py
@@ -88,6 +88,11 @@ def expected_image_std_random_init(ref_path: Path) -> Image.Image:
     return _img_open(ref_path / "expected_std_random_init.png").convert("RGB")
 
 
+@pytest.fixture
+def expected_image_std_sde_random_init(ref_path: Path) -> Image.Image:
+    return _img_open(ref_path / "expected_std_sde_random_init.png").convert("RGB")
+
+
 @pytest.fixture
 def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
     return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB")
@@ -560,6 +565,24 @@ def sd15_std(
     return sd15
 
 
+@pytest.fixture
+def sd15_std_sde(
+    text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
+) -> StableDiffusion_1:
+    if test_device.type == "cpu":
+        warn("not running on CPU, skipping")
+        pytest.skip()
+
+    sde_solver = DPMSolver(num_inference_steps=30, last_step_first_order=True, params=SolverParams(sde_variance=1.0))
+    sd15 = StableDiffusion_1(device=test_device, solver=sde_solver)
+
+    sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
+    sd15.lda.load_from_safetensors(lda_weights)
+    sd15.unet.load_from_safetensors(unet_weights_std)
+
+    return sd15
+
+
 @pytest.fixture
 def sd15_std_float16(
     text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
@@ -831,6 +854,33 @@ def test_diffusion_std_random_init(
     ensure_similar_images(predicted_image, expected_image_std_random_init)
 
 
+@no_grad()
+def test_diffusion_std_sde_random_init(
+    sd15_std_sde: StableDiffusion_1, expected_image_std_sde_random_init: Image.Image, test_device: torch.device
+):
+    sd15 = sd15_std_sde
+
+    prompt = "a cute cat, detailed high-quality professional image"
+    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
+    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+
+    sd15.set_inference_steps(50)
+
+    manual_seed(2)
+    x = torch.randn(1, 4, 64, 64, device=test_device)
+
+    for step in sd15.steps:
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+    predicted_image = sd15.lda.latents_to_image(x)
+
+    ensure_similar_images(predicted_image, expected_image_std_sde_random_init)
+
+
 @no_grad()
 def test_diffusion_batch2(sd15_std: StableDiffusion_1):
     sd15 = sd15_std
diff --git a/tests/e2e/test_diffusion_ref/README.md b/tests/e2e/test_diffusion_ref/README.md
index 9e7075a..926127f 100644
--- a/tests/e2e/test_diffusion_ref/README.md
+++ b/tests/e2e/test_diffusion_ref/README.md
@@ -67,6 +67,35 @@ Special cases:
 
 - `kitchen_dog.png` is generated with the same Diffusers script and negative prompt, seed 12, positive prompt "a small brown dog, detailed high-quality professional image, sitting on a chair, in a kitchen".
 
+- `expected_std_sde_random_init.png` is generated with the following code:
+
+```python
+import torch
+from diffusers import StableDiffusionPipeline
+from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
+
+from refiners.fluxion.utils import manual_seed
+
+diffusers_solver = DPMSolverMultistepScheduler.from_config(  # type: ignore
+    {
+        "beta_end": 0.012,
+        "beta_schedule": "scaled_linear",
+        "beta_start": 0.00085,
+        "algorithm_type": "sde-dpmsolver++",
+        "use_karras_sigmas": False,
+        "final_sigmas_type": "sigma_min",
+        "euler_at_final": True,
+    }
+)
+model_id = "runwayml/stable-diffusion-v1-5"
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32, scheduler=diffusers_solver)
+pipe = pipe.to("cuda")
+prompt = "a cute cat, detailed high-quality professional image"
+negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
+manual_seed(2)
+image = pipe(prompt, negative_prompt=negative_prompt, guidance_scale=7.5).images[0]
+```
+
 - `kitchen_mask.png` is made manually.
 
 - Controlnet guides have been manually generated (x) using open source software and models, namely:
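Note: the Diffusers snippet above and the refiners e2e test are meant to be equivalent runs; the correspondence below is inferred from this diff, not an authoritative mapping:

```python
# Inferred configuration correspondence (assumption, for review convenience):
#   algorithm_type="sde-dpmsolver++"  <->  SolverParams(sde_variance=1.0)
#   euler_at_final=True               <->  DPMSolver(last_step_first_order=True)
#   pipeline default of 50 steps      <->  sd15.set_inference_steps(50)
#   manual_seed(2), guidance 7.5      <->  manual_seed(2), condition_scale=7.5
```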
diff --git a/tests/e2e/test_diffusion_ref/expected_std_sde_random_init.png b/tests/e2e/test_diffusion_ref/expected_std_sde_random_init.png
new file mode 100644
index 0000000..954a4e4
Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_std_sde_random_init.png differ
diff --git a/tests/foundationals/latent_diffusion/test_solvers.py b/tests/foundationals/latent_diffusion/test_solvers.py
index f9888e1..5522027 100644
--- a/tests/foundationals/latent_diffusion/test_solvers.py
+++ b/tests/foundationals/latent_diffusion/test_solvers.py
@@ -59,6 +59,45 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
     assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
 
 
+@pytest.mark.parametrize("n_steps, last_step_first_order", [(5, False), (5, True), (30, False), (30, True)])
+def test_dpm_solver_sde_diffusers(n_steps: int, last_step_first_order: bool):
+    from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler  # type: ignore
+
+    manual_seed(0)
+
+    diffusers_scheduler = DiffuserScheduler(
+        beta_schedule="scaled_linear",
+        beta_start=0.00085,
+        beta_end=0.012,
+        lower_order_final=False,
+        euler_at_final=last_step_first_order,
+        final_sigmas_type="sigma_min",  # default before Diffusers 0.26.0
+        algorithm_type="sde-dpmsolver++",
+    )
+    diffusers_scheduler.set_timesteps(n_steps)
+    solver = DPMSolver(
+        num_inference_steps=n_steps,
+        last_step_first_order=last_step_first_order,
+        params=SolverParams(sde_variance=1.0),
+    )
+    assert equal(solver.timesteps, diffusers_scheduler.timesteps)
+
+    sample = randn(1, 3, 32, 32)
+    predicted_noise = randn(1, 3, 32, 32)
+
+    manual_seed(37)
+    diffusers_outputs: list[Tensor] = [
+        cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample)  # type: ignore
+        for timestep in diffusers_scheduler.timesteps
+    ]
+
+    manual_seed(37)
+    refiners_outputs = [solver(x=sample, predicted_noise=predicted_noise, step=step) for step in range(n_steps)]
+
+    for step, (diffusers_output, refiners_output) in enumerate(zip(diffusers_outputs, refiners_outputs)):
+        assert allclose(diffusers_output, refiners_output, rtol=0.01, atol=1e-6), f"outputs differ at step {step}"
+
+
 def test_ddim_diffusers():
     from diffusers import DDIMScheduler  # type: ignore
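Note: the parity test above relies on RNG discipline rather than an explicit `Generator`: both loops run after the same `manual_seed(37)` and draw Gaussian samples of the same shapes in the same order, so the Diffusers scheduler and the refiners solver consume identical noise from the global RNG. A minimal sketch of that principle:

```python
# Sketch: reseeding the global RNG makes two independent loops draw identical noise.
import torch

torch.manual_seed(37)
a = [torch.randn(2, 2) for _ in range(3)]

torch.manual_seed(37)
b = [torch.randn(2, 2) for _ in range(3)]

assert all(torch.equal(x, y) for x, y in zip(a, b))
```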