refiners/tests/foundationals/latent_diffusion/test_solvers.py
2024-09-06 15:30:02 +02:00

311 lines
11 KiB
Python

import itertools
from typing import cast
from warnings import warn
import pytest
from torch import Generator, Tensor, allclose, device as Device, equal, isclose, randn, tensor
from refiners.fluxion import manual_seed
from refiners.foundationals.latent_diffusion.solvers import (
DDIM,
DDPM,
DPMSolver,
Euler,
FrankenSolver,
LCMSolver,
ModelPredictionType,
NoiseSchedule,
Solver,
SolverParams,
TimestepSpacing,
)
def test_ddpm_diffusers():
    """Check that DDPM and Diffusers' DDPMScheduler produce equal timesteps for 1000 steps."""
    from diffusers import DDPMScheduler  # type: ignore

    reference = DDPMScheduler(beta_schedule="scaled_linear", beta_start=0.00085, beta_end=0.012)
    reference.set_timesteps(1000)

    refiners_solver = DDPM(num_inference_steps=1000)

    assert equal(reference.timesteps, refiners_solver.timesteps)
@pytest.mark.parametrize(
    "n_steps, last_step_first_order, sde_variance, use_karras_sigmas",
    list(itertools.product([5, 30], [False, True], [0.0, 1.0], [False, True])),
)
def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool, sde_variance: float, use_karras_sigmas: bool):
    """Compare DPMSolver against Diffusers' DPMSolverMultistepScheduler.

    For every parameter combination both implementations must expose equal
    timesteps, and each step output must agree within tolerance when both are
    fed the same sample and predicted noise.
    """
    from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler  # type: ignore

    manual_seed(0)

    # A variance of 1.0 selects the SDE flavour on the Diffusers side.
    if sde_variance == 1.0:
        algorithm_type = "sde-dpmsolver++"
    else:
        algorithm_type = "dpmsolver++"

    reference = DiffuserScheduler(
        beta_schedule="scaled_linear",
        beta_start=0.00085,
        beta_end=0.012,
        lower_order_final=False,
        euler_at_final=last_step_first_order,
        final_sigmas_type="sigma_min",  # default before Diffusers 0.26.0
        algorithm_type=algorithm_type,
        use_karras_sigmas=use_karras_sigmas,
    )
    reference.set_timesteps(n_steps)

    sigma_schedule = NoiseSchedule.KARRAS if use_karras_sigmas else None
    solver_params = SolverParams(sde_variance=sde_variance, sigma_schedule=sigma_schedule)
    solver = DPMSolver(
        num_inference_steps=n_steps,
        last_step_first_order=last_step_first_order,
        params=solver_params,
    )
    assert equal(solver.timesteps, reference.timesteps)

    sample = randn(1, 3, 32, 32)
    predicted_noise = randn(1, 3, 32, 32)

    # Each pass is preceded by the same seed (37); the reference pass runs to
    # completion before the refiners pass starts.
    manual_seed(37)
    diffusers_outputs: list[Tensor] = []
    for timestep in reference.timesteps:
        step_result = reference.step(predicted_noise, timestep, sample)  # type: ignore
        diffusers_outputs.append(cast(Tensor, step_result.prev_sample))  # type: ignore

    manual_seed(37)
    refiners_outputs: list[Tensor] = []
    for step in range(n_steps):
        refiners_outputs.append(solver(x=sample, predicted_noise=predicted_noise, step=step))

    # A looser absolute tolerance is used when Karras sigmas are enabled.
    if use_karras_sigmas:
        atol = 1e-4
    else:
        atol = 1e-6

    for step, (expected, actual) in enumerate(zip(diffusers_outputs, refiners_outputs)):
        assert allclose(expected, actual, rtol=0.01, atol=atol), f"outputs differ at step {step}"
def test_ddim_diffusers():
    """Compare DDIM against Diffusers' DDIMScheduler over 30 inference steps.

    Timesteps must be equal and each step output must be close (rtol=0.01).
    """
    from diffusers import DDIMScheduler  # type: ignore

    manual_seed(0)

    reference = DDIMScheduler(
        num_train_timesteps=1000,
        beta_schedule="scaled_linear",
        beta_start=0.00085,
        beta_end=0.012,
        steps_offset=1,
        clip_sample=False,
    )
    reference.set_timesteps(30)

    solver = DDIM(num_inference_steps=30)
    assert equal(solver.timesteps, reference.timesteps)

    sample = randn(1, 4, 32, 32)
    predicted_noise = randn(1, 4, 32, 32)

    for step, timestep in enumerate(reference.timesteps):
        step_result = reference.step(predicted_noise, timestep, sample)  # type: ignore
        expected = cast(Tensor, step_result.prev_sample)  # type: ignore
        actual = solver(x=sample, predicted_noise=predicted_noise, step=step)
        assert allclose(expected, actual, rtol=0.01), f"outputs differ at step {step}"
@pytest.mark.parametrize("model_prediction_type", [ModelPredictionType.NOISE, ModelPredictionType.SAMPLE])
def test_euler_diffusers(model_prediction_type: ModelPredictionType):
    """Compare Euler against Diffusers' EulerDiscreteScheduler for both prediction types.

    Checks equal timesteps, a matching `init_noise_sigma`, and step outputs
    that are close (rtol=0.02) over 30 inference steps.
    """
    from diffusers import EulerDiscreteScheduler  # type: ignore

    manual_seed(0)

    # Translate the refiners prediction type into the Diffusers name.
    if model_prediction_type == ModelPredictionType.NOISE:
        diffusers_prediction_type = "epsilon"
    else:
        diffusers_prediction_type = "sample"

    reference = EulerDiscreteScheduler(
        num_train_timesteps=1000,
        beta_schedule="scaled_linear",
        beta_start=0.00085,
        beta_end=0.012,
        steps_offset=1,
        timestep_spacing="linspace",
        use_karras_sigmas=False,
        prediction_type=diffusers_prediction_type,
    )
    reference.set_timesteps(30)

    solver_params = SolverParams(model_prediction_type=model_prediction_type)
    solver = Euler(num_inference_steps=30, params=solver_params)
    assert equal(solver.timesteps, reference.timesteps)

    sample = randn(1, 4, 32, 32)
    predicted_noise = randn(1, 4, 32, 32)

    expected_sigma = reference.init_noise_sigma  # type: ignore
    assert isinstance(expected_sigma, Tensor)
    assert isclose(expected_sigma, solver.init_noise_sigma), "init_noise_sigma differ"

    for step, timestep in enumerate(reference.timesteps):
        step_result = reference.step(predicted_noise, timestep, sample)  # type: ignore
        expected = cast(Tensor, step_result.prev_sample)  # type: ignore
        actual = solver(x=sample, predicted_noise=predicted_noise, step=step)
        assert allclose(expected, actual, rtol=0.02), f"outputs differ at step {step}"
def test_franken_diffusers():
    """Check that FrankenSolver reproduces a directly driven EulerDiscreteScheduler exactly.

    Two schedulers are built from the same configuration: one is stepped
    directly as the reference, the other is handed to FrankenSolver through a
    factory callable. Timesteps, the initial noise sigma and every step output
    must be strictly equal.
    """
    from diffusers import EulerDiscreteScheduler  # type: ignore

    manual_seed(0)

    scheduler_config = {
        "num_train_timesteps": 1000,
        "beta_schedule": "scaled_linear",
        "beta_start": 0.00085,
        "beta_end": 0.012,
        "steps_offset": 1,
        "timestep_spacing": "linspace",
        "use_karras_sigmas": False,
    }

    reference = EulerDiscreteScheduler(**scheduler_config)  # type: ignore
    reference.set_timesteps(30)

    # Separate instance for the solver so the reference keeps its own state.
    wrapped_scheduler = EulerDiscreteScheduler(**scheduler_config)  # type: ignore
    solver = FrankenSolver(lambda: wrapped_scheduler, num_inference_steps=30)
    assert equal(solver.timesteps, reference.timesteps)

    sample = randn(1, 4, 32, 32)
    predicted_noise = randn(1, 4, 32, 32)

    expected_sigma = reference.init_noise_sigma  # type: ignore
    assert isinstance(expected_sigma, Tensor)
    actual_sigma = solver.scale_model_input(tensor(1), step=-1)
    assert equal(expected_sigma, actual_sigma), "init_noise_sigma differ"

    for step, timestep in enumerate(reference.timesteps):
        step_result = reference.step(predicted_noise, timestep, sample)  # type: ignore
        expected = cast(Tensor, step_result.prev_sample)  # type: ignore
        actual = solver(x=sample, predicted_noise=predicted_noise, step=step)
        assert equal(expected, actual), f"outputs differ at step {step}"
def test_lcm_diffusers():
    """Compare LCMSolver against Diffusers' LCMScheduler over 4 inference steps.

    Besides equal timesteps, the per-timestep scale factor and noise ratio must
    match exactly and each step output must be close (rtol=0.01).
    """
    from diffusers import LCMScheduler  # type: ignore

    manual_seed(0)

    # LCMScheduler is stochastic, make sure we use identical generators
    diffusers_generator = Generator().manual_seed(42)
    refiners_generator = Generator().manual_seed(42)

    reference = LCMScheduler()
    reference.set_timesteps(4)

    solver = LCMSolver(num_inference_steps=4)
    assert equal(solver.timesteps, reference.timesteps)

    sample = randn(1, 4, 32, 32)
    predicted_noise = randn(1, 4, 32, 32)

    for step, timestep in enumerate(reference.timesteps):
        # Derive the reference coefficients from the cumulative alpha product.
        alpha_prod_t = reference.alphas_cumprod[timestep]
        expected_scale_factor = alpha_prod_t.sqrt()
        expected_noise_ratio = (1 - alpha_prod_t).sqrt()

        assert solver.cumulative_scale_factors[timestep] == expected_scale_factor
        assert solver.noise_std[timestep] == expected_noise_ratio

        step_result = reference.step(predicted_noise, timestep, sample, generator=diffusers_generator)  # type: ignore
        expected = cast(Tensor, step_result.prev_sample)  # type: ignore

        actual = solver(x=sample, predicted_noise=predicted_noise, step=step, generator=refiners_generator)

        assert allclose(actual, expected, rtol=0.01), f"outputs differ at step {step}"
def test_solver_remove_noise():
    """Check `DDIM.remove_noise` against Diffusers' `pred_original_sample`.

    At each of the 30 steps the denoised estimate from refiners must be close
    (rtol=0.01) to the original-sample prediction returned by DDIMScheduler.
    """
    from diffusers import DDIMScheduler  # type: ignore

    manual_seed(0)

    reference = DDIMScheduler(
        num_train_timesteps=1000,
        beta_schedule="scaled_linear",
        beta_start=0.00085,
        beta_end=0.012,
        steps_offset=1,
        clip_sample=False,
    )
    reference.set_timesteps(30)

    solver = DDIM(num_inference_steps=30)

    sample = randn(1, 4, 32, 32)
    noise = randn(1, 4, 32, 32)

    for step, timestep in enumerate(reference.timesteps):
        step_result = reference.step(noise, timestep, sample)  # type: ignore
        expected = cast(Tensor, step_result.pred_original_sample)  # type: ignore
        actual = solver.remove_noise(x=sample, noise=noise, step=step)
        assert allclose(expected, actual, rtol=0.01), f"outputs differ at step {step}"
def test_solver_device(test_device: Device):
    """Check that `add_noise` returns a tensor on the device the solver was built with.

    The test is skipped when `test_device` is the CPU.

    Args:
        test_device: Device on which the solver and the input tensors are created.
    """
    if test_device.type == "cpu":
        # The previous message ("not running on CPU") contradicted the condition
        # it was emitted under, and the skip itself carried no reason.
        skip_reason = "running on CPU, skipping"
        warn(skip_reason)
        pytest.skip(skip_reason)

    scheduler = DDIM(num_inference_steps=30, device=test_device)
    x = randn(1, 4, 32, 32, device=test_device)
    noise = randn(1, 4, 32, 32, device=test_device)
    noised = scheduler.add_noise(x, noise, scheduler.first_inference_step)
    assert noised.device == test_device
def test_solver_add_noise(test_device: Device):
    """Check that batched `add_noise` agrees with the single-sample call.

    The same latent and noise are duplicated into a batch of two with step
    `[0, 0]`; each batch entry must be close to the unbatched result at step 0.
    """
    solver = DDIM(num_inference_steps=30, device=test_device)
    latent = randn(1, 4, 32, 32, device=test_device)
    noise = randn(1, 4, 32, 32, device=test_device)

    single = solver.add_noise(x=latent, noise=noise, step=0)

    latent_pair = latent.repeat(2, 1, 1, 1)
    noise_pair = noise.repeat(2, 1, 1, 1)
    batched = solver.add_noise(x=latent_pair, noise=noise_pair, step=[0, 0])

    for index in (0, 1):
        assert allclose(single, batched[index])
@pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS])
def test_solver_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device):
    """Check the length and end points of `scale_factors` for each noise schedule.

    The solver must expose 1000 scale factors, whose first and last entries
    equal one minus the initial and final diffusion rates respectively.
    """
    solver_params = SolverParams(noise_schedule=noise_schedule)
    solver = DDIM(num_inference_steps=30, params=solver_params, device=test_device)

    scale_factors = solver.scale_factors
    assert len(scale_factors) == 1000
    assert scale_factors[0] == 1 - solver.params.initial_diffusion_rate
    assert scale_factors[-1] == 1 - solver.params.final_diffusion_rate
def test_solver_timestep_spacing():
    """Check `Solver.generate_timesteps` for each spacing against fixed expected values."""
    # Tests we get the results from [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 2.
    expected_by_spacing = [
        (TimestepSpacing.LINSPACE_ROUNDED, [1000, 889, 778, 667, 556, 445, 334, 223, 112, 1]),
        (TimestepSpacing.LEADING, [901, 801, 701, 601, 501, 401, 301, 201, 101, 1]),
        (TimestepSpacing.TRAILING, [1000, 900, 800, 700, 600, 500, 400, 300, 200, 100]),
    ]

    for spacing, expected in expected_by_spacing:
        timesteps = Solver.generate_timesteps(
            spacing=spacing,
            num_inference_steps=10,
            num_train_timesteps=1000,
            offset=1,
        )
        assert equal(timesteps, tensor(expected))