2024-09-06 10:56:24 +00:00
import itertools
2023-08-04 13:28:41 +00:00
from typing import cast
2023-09-21 09:47:11 +00:00
from warnings import warn
2023-12-11 10:46:38 +00:00
import pytest
2024-09-25 10:44:35 +00:00
import torch
from torch import Tensor , device as Device
2023-12-11 10:46:38 +00:00
2023-12-04 14:08:34 +00:00
from refiners . fluxion import manual_seed
2024-02-22 11:02:58 +00:00
from refiners . foundationals . latent_diffusion . solvers import (
DDIM ,
DDPM ,
DPMSolver ,
Euler ,
2024-07-10 14:08:05 +00:00
FrankenSolver ,
2024-02-22 11:02:58 +00:00
LCMSolver ,
2024-02-22 14:16:22 +00:00
ModelPredictionType ,
2024-02-22 11:02:58 +00:00
NoiseSchedule ,
Solver ,
2024-02-22 14:16:22 +00:00
SolverParams ,
2024-02-22 11:02:58 +00:00
TimestepSpacing ,
)
2023-12-13 10:43:11 +00:00
def test_ddpm_diffusers():
    """Check that refiners' DDPM solver yields exactly the timesteps of diffusers' DDPMScheduler."""
    from diffusers import DDPMScheduler  # type: ignore

    num_steps = 1000
    reference = DDPMScheduler(beta_schedule="scaled_linear", beta_start=0.00085, beta_end=0.012)
    reference.set_timesteps(num_steps)

    ddpm = DDPM(num_inference_steps=num_steps)

    assert torch.equal(reference.timesteps, ddpm.timesteps)
2023-08-04 13:28:41 +00:00
2024-09-06 10:56:24 +00:00
@pytest.mark.parametrize(
    "n_steps, last_step_first_order, sde_variance, use_karras_sigmas",
    list(itertools.product([5, 30], [False, True], [0.0, 1.0], [False, True])),
)
def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool, sde_variance: float, use_karras_sigmas: bool):
    """Compare refiners' DPMSolver with diffusers' DPMSolverMultistepScheduler step by step.

    The grid covers both step counts, an optional first-order last step, the ``dpmsolver++`` /
    ``sde-dpmsolver++`` algorithms (selected by ``sde_variance``) and optional Karras sigmas.
    """
    from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler  # type: ignore

    manual_seed(0)
    stochastic = sde_variance == 1.0
    diffusers_scheduler = DiffuserScheduler(
        beta_schedule="scaled_linear",
        beta_start=0.00085,
        beta_end=0.012,
        lower_order_final=False,
        euler_at_final=last_step_first_order,
        final_sigmas_type="sigma_min",  # default before Diffusers 0.26.0
        algorithm_type="sde-dpmsolver++" if stochastic else "dpmsolver++",
        use_karras_sigmas=use_karras_sigmas,
    )
    diffusers_scheduler.set_timesteps(n_steps)

    sigma_schedule = NoiseSchedule.KARRAS if use_karras_sigmas else None
    solver = DPMSolver(
        num_inference_steps=n_steps,
        last_step_first_order=last_step_first_order,
        params=SolverParams(sde_variance=sde_variance, sigma_schedule=sigma_schedule),
    )
    assert torch.equal(solver.timesteps, diffusers_scheduler.timesteps)

    sample = torch.randn(1, 3, 32, 32)
    predicted_noise = torch.randn(1, 3, 32, 32)

    # Both rollouts start from the same global RNG state (relevant for the SDE variant).
    manual_seed(37)
    diffusers_outputs: list[Tensor] = []
    for timestep in diffusers_scheduler.timesteps:
        step_output = diffusers_scheduler.step(predicted_noise, timestep, sample)  # type: ignore
        diffusers_outputs.append(cast(Tensor, step_output.prev_sample))  # type: ignore
    manual_seed(37)
    refiners_outputs = [solver(x=sample, predicted_noise=predicted_noise, step=step) for step in range(n_steps)]

    # Absolute tolerance per configuration: loosest with Karras sigmas, then SDE, then deterministic.
    atol = 1e-4 if use_karras_sigmas else (1e-6 if stochastic else 1e-8)

    for step, (expected, actual) in enumerate(zip(diffusers_outputs, refiners_outputs)):
        assert torch.allclose(expected, actual, rtol=0.01, atol=atol), f"outputs differ at step {step}"
2024-07-23 08:52:40 +00:00
2023-12-12 16:18:29 +00:00
def test_ddim_diffusers():
    """Step-by-step comparison of refiners' DDIM solver with diffusers' DDIMScheduler over 30 steps."""
    from diffusers import DDIMScheduler  # type: ignore

    num_steps = 30
    manual_seed(0)
    reference = DDIMScheduler(
        beta_end=0.012,
        beta_schedule="scaled_linear",
        beta_start=0.00085,
        num_train_timesteps=1000,
        steps_offset=1,
        clip_sample=False,
    )
    reference.set_timesteps(num_steps)
    ddim = DDIM(num_inference_steps=num_steps)
    assert torch.equal(ddim.timesteps, reference.timesteps)

    x = torch.randn(1, 4, 32, 32)
    noise_prediction = torch.randn(1, 4, 32, 32)

    for step, timestep in enumerate(reference.timesteps):
        expected = cast(Tensor, reference.step(noise_prediction, timestep, x).prev_sample)  # type: ignore
        actual = ddim(x=x, predicted_noise=noise_prediction, step=step)
        assert torch.allclose(expected, actual, rtol=0.01), f"outputs differ at step {step}"
2023-09-21 09:47:11 +00:00
2024-02-22 14:16:22 +00:00
@pytest.mark.parametrize("model_prediction_type", [ModelPredictionType.NOISE, ModelPredictionType.SAMPLE])
def test_euler_diffusers(model_prediction_type: ModelPredictionType):
    """Compare refiners' Euler solver with diffusers' EulerDiscreteScheduler.

    Runs once for noise prediction ("epsilon" in diffusers) and once for sample prediction. It checks
    the timesteps, the initial noise sigma and every step output.
    """
    from diffusers import EulerDiscreteScheduler  # type: ignore

    manual_seed(0)
    if model_prediction_type == ModelPredictionType.NOISE:
        diffusers_prediction_type = "epsilon"
    else:
        diffusers_prediction_type = "sample"
    reference = EulerDiscreteScheduler(
        beta_end=0.012,
        beta_schedule="scaled_linear",
        beta_start=0.00085,
        num_train_timesteps=1000,
        steps_offset=1,
        timestep_spacing="linspace",
        use_karras_sigmas=False,
        prediction_type=diffusers_prediction_type,
    )
    reference.set_timesteps(30)
    euler = Euler(num_inference_steps=30, params=SolverParams(model_prediction_type=model_prediction_type))
    assert torch.equal(euler.timesteps, reference.timesteps)

    x = torch.randn(1, 4, 32, 32)
    model_output = torch.randn(1, 4, 32, 32)

    expected_sigma = reference.init_noise_sigma  # type: ignore
    assert isinstance(expected_sigma, Tensor)
    assert torch.isclose(expected_sigma, euler.init_noise_sigma), "init_noise_sigma differ"

    # Keep the diffusers/refiners calls interleaved per step, as in the reference ordering.
    for step, timestep in enumerate(reference.timesteps):
        expected = cast(Tensor, reference.step(model_output, timestep, x).prev_sample)  # type: ignore
        actual = euler(x=x, predicted_noise=model_output, step=step)
        assert torch.allclose(expected, actual, rtol=0.02), f"outputs differ at step {step}"
2024-01-10 10:32:40 +00:00
2024-07-10 14:08:05 +00:00
def test_franken_diffusers():
    """Check that FrankenSolver wrapping a diffusers EulerDiscreteScheduler reproduces it exactly.

    Timesteps, the initial noise sigma and every step output must be bit-identical (``torch.equal``)
    to those of a standalone scheduler built with the same configuration.
    """
    from diffusers import EulerDiscreteScheduler  # type: ignore

    manual_seed(0)
    params = {
        "beta_end": 0.012,
        "beta_schedule": "scaled_linear",
        "beta_start": 0.00085,
        "num_train_timesteps": 1000,
        "steps_offset": 1,
        "timestep_spacing": "linspace",
        "use_karras_sigmas": False,
    }
    diffusers_scheduler = EulerDiscreteScheduler(**params)  # type: ignore
    diffusers_scheduler.set_timesteps(30)
    # FrankenSolver receives its own scheduler instance through a factory instead of sharing the
    # reference one. NOTE(review): presumably because diffusers schedulers keep per-run state; confirm.
    diffusers_scheduler_2 = EulerDiscreteScheduler(**params)  # type: ignore
    solver = FrankenSolver(lambda: diffusers_scheduler_2, num_inference_steps=30)
    assert torch.equal(solver.timesteps, diffusers_scheduler.timesteps)

    sample = torch.randn(1, 4, 32, 32)
    predicted_noise = torch.randn(1, 4, 32, 32)

    ref_init_noise_sigma = diffusers_scheduler.init_noise_sigma  # type: ignore
    assert isinstance(ref_init_noise_sigma, Tensor)
    # The test treats scale_model_input(1, step=-1) as the solver's initial noise sigma.
    init_noise_sigma = solver.scale_model_input(torch.tensor(1), step=-1)
    assert torch.equal(ref_init_noise_sigma, init_noise_sigma), "init_noise_sigma differ"

    for step, timestep in enumerate(diffusers_scheduler.timesteps):
        diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample)  # type: ignore
        # Fixed: the keyword value was the broken token sequence `s tep`, a syntax error.
        refiners_output = solver(x=sample, predicted_noise=predicted_noise, step=step)

        assert torch.equal(diffusers_output, refiners_output), f"outputs differ at step {step}"
2024-07-10 14:08:05 +00:00
2024-02-15 17:41:00 +00:00
def test_lcm_diffusers():
    """Compare refiners' LCMSolver with diffusers' LCMScheduler over 4 steps.

    At each timestep it checks the scale and noise coefficients exactly, then compares the step
    outputs with relative tolerance.
    """
    from diffusers import LCMScheduler  # type: ignore

    manual_seed(0)
    # LCMScheduler is stochastic: give each side its own generator seeded identically.
    diffusers_generator = torch.Generator().manual_seed(42)
    refiners_generator = torch.Generator().manual_seed(42)

    reference = LCMScheduler()
    reference.set_timesteps(4)
    lcm = LCMSolver(num_inference_steps=4)
    assert torch.equal(lcm.timesteps, reference.timesteps)

    x = torch.randn(1, 4, 32, 32)
    model_output = torch.randn(1, 4, 32, 32)

    for step, timestep in enumerate(reference.timesteps):
        alpha_prod_t = reference.alphas_cumprod[timestep]
        # sqrt(alpha_bar_t) and sqrt(1 - alpha_bar_t) must match refiners' tables exactly.
        assert lcm.cumulative_scale_factors[timestep] == alpha_prod_t.sqrt()
        assert lcm.noise_std[timestep] == (1 - alpha_prod_t).sqrt()

        step_output = reference.step(model_output, timestep, x, generator=diffusers_generator)  # type: ignore
        expected = cast(Tensor, step_output.prev_sample)  # type: ignore
        actual = lcm(x=x, predicted_noise=model_output, step=step, generator=refiners_generator)
        assert torch.allclose(actual, expected, rtol=0.01), f"outputs differ at step {step}"
2024-02-15 17:41:00 +00:00
2024-02-22 11:02:58 +00:00
def test_solver_remove_noise():
    """Check that ``DDIM.remove_noise`` matches diffusers' ``pred_original_sample`` at every step."""
    from diffusers import DDIMScheduler  # type: ignore

    manual_seed(0)
    reference = DDIMScheduler(
        beta_end=0.012,
        beta_schedule="scaled_linear",
        beta_start=0.00085,
        num_train_timesteps=1000,
        steps_offset=1,
        clip_sample=False,
    )
    reference.set_timesteps(30)
    ddim = DDIM(num_inference_steps=30)

    x = torch.randn(1, 4, 32, 32)
    eps = torch.randn(1, 4, 32, 32)

    for step, timestep in enumerate(reference.timesteps):
        expected = cast(Tensor, reference.step(eps, timestep, x).pred_original_sample)  # type: ignore
        actual = ddim.remove_noise(x=x, noise=eps, step=step)
        assert torch.allclose(expected, actual, rtol=0.01), f"outputs differ at step {step}"
2023-10-05 14:44:38 +00:00
2024-02-22 11:02:58 +00:00
def test_solver_device(test_device: Device):
    """Check that ``add_noise`` returns a tensor on the solver's device.

    Skipped on CPU: device placement is only worth checking on a non-CPU device.
    """
    if test_device.type == "cpu":
        # The message must describe the actual skip condition (we ARE on CPU here).
        warn("running on CPU, skipping (test requires a non-CPU device)")
        pytest.skip("test requires a non-CPU device")
    scheduler = DDIM(num_inference_steps=30, device=test_device)
    x = torch.randn(1, 4, 32, 32, device=test_device)
    noise = torch.randn(1, 4, 32, 32, device=test_device)
    noised = scheduler.add_noise(x, noise, scheduler.first_inference_step)
    assert noised.device == test_device
2024-01-30 16:47:06 +00:00
2024-04-18 09:42:31 +00:00
def test_solver_add_noise(test_device: Device):
    """Noising a batch with a per-item list of steps must match noising a single item at that step."""
    ddim = DDIM(num_inference_steps=30, device=test_device)
    latent = torch.randn(1, 4, 32, 32, device=test_device)
    noise = torch.randn(1, 4, 32, 32, device=test_device)

    single = ddim.add_noise(x=latent, noise=noise, step=0)
    batched = ddim.add_noise(
        x=latent.repeat(2, 1, 1, 1),
        noise=noise.repeat(2, 1, 1, 1),
        step=[0, 0],
    )

    for index in range(2):
        assert torch.allclose(single, batched[index])
2024-04-18 09:42:31 +00:00
2024-01-30 16:47:06 +00:00
@pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS])
def test_solver_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device):
    """Each noise schedule yields 1000 scale factors whose endpoints are 1 - the configured diffusion rates."""
    ddim = DDIM(
        num_inference_steps=30,
        params=SolverParams(noise_schedule=noise_schedule),
        device=test_device,
    )
    scale_factors = ddim.scale_factors
    solver_params = ddim.params

    assert len(scale_factors) == 1000
    assert scale_factors[0] == 1 - solver_params.initial_diffusion_rate
    assert scale_factors[-1] == 1 - solver_params.final_diffusion_rate
2024-02-22 11:02:58 +00:00
def test_solver_timestep_spacing():
    """Check timestep spacings against table 2 of arXiv:2305.08891.

    Reference: "Common Diffusion Noise Schedules and Sample Steps are Flawed"
    (https://arxiv.org/abs/2305.08891), with 10 inference steps, 1000 training steps, offset 1.
    """
    expected_by_spacing = [
        (TimestepSpacing.LINSPACE_ROUNDED, [1000, 889, 778, 667, 556, 445, 334, 223, 112, 1]),
        (TimestepSpacing.LEADING, [901, 801, 701, 601, 501, 401, 301, 201, 101, 1]),
        (TimestepSpacing.TRAILING, [1000, 900, 800, 700, 600, 500, 400, 300, 200, 100]),
    ]
    for spacing, expected in expected_by_spacing:
        timesteps = Solver.generate_timesteps(
            spacing=spacing,
            num_inference_steps=10,
            num_train_timesteps=1000,
            offset=1,
        )
        assert torch.equal(timesteps, torch.tensor(expected))


def test_dpm_bfloat16 ( test_device : Device ) :
if test_device . type == " cpu " :
warn ( " not running on CPU, skipping " )
pytest . skip ( )
2024-09-25 21:17:58 +00:00
n_steps = 5
manual_seed ( 0 )
solver_f32 = DPMSolver ( num_inference_steps = n_steps , dtype = torch . float32 )
solver_bf16 = DPMSolver ( num_inference_steps = n_steps , dtype = torch . bfloat16 )
assert torch . equal ( solver_bf16 . timesteps , solver_f32 . timesteps )