diff --git a/test/BetaSchedules.jl b/test/BetaSchedules.jl index b90ad4c..9285eb4 100644 --- a/test/BetaSchedules.jl +++ b/test/BetaSchedules.jl @@ -2,32 +2,49 @@ using Diffusers.BetaSchedules using Test @testset "Variance schedules tests" begin + @testset "SNR decreases monotonically" begin T = 1000 - β_linear = linear_beta_schedule(T) - β_scaled_linear = scaled_linear_beta_schedule(T) - β_cosine = cosine_beta_schedule(T) - β_sigmoid = sigmoid_beta_schedule(T) - β_exponential = exponential_beta_schedule(T) + for beta_schedule_type in [ + linear_beta_schedule, + scaled_linear_beta_schedule, + cosine_beta_schedule, + sigmoid_beta_schedule, + exponential_beta_schedule, + ] + @testset "Variance schedule == $beta_schedule_type" begin + β = beta_schedule_type(T) + α = 1 .- β + α̅ = cumprod(α) - α̅_linear = cumprod(1 .- β_linear) - α̅_scaled_linear = cumprod(1 .- β_scaled_linear) - α̅_cosine = cumprod(1 .- β_cosine) - α̅_sigmoid = cumprod(1 .- β_sigmoid) - α̅_exponential = cumprod(1 .- β_exponential) + SNR = α̅ ./ (1 .- α̅) - # arxiv:2208.11970 Eq. 109 - SNR_linear = α̅_linear ./ (1 .- α̅_linear) - SNR_scaled_linear = α̅_scaled_linear ./ (1 .- α̅_scaled_linear) - SNR_cosine = α̅_cosine ./ (1 .- α̅_cosine) - SNR_sigmoid = α̅_sigmoid ./ (1 .- α̅_sigmoid) - SNR_exponential = α̅_exponential ./ (1 .- α̅_exponential) - - @test all(diff(SNR_linear) .<= 0) - @test all(diff(SNR_scaled_linear) .<= 0) - @test all(diff(SNR_cosine) .<= 0) - @test all(diff(SNR_sigmoid) .<= 0) - @test all(diff(SNR_exponential) .<= 0) + @test all(diff(SNR) .<= 0) + end + end end + + @testset "ZeroSNR rescaling" begin + T = 1000 + + for beta_schedule_type in [ + linear_beta_schedule, + scaled_linear_beta_schedule, + cosine_beta_schedule, + sigmoid_beta_schedule, + exponential_beta_schedule, + ] + @testset "Variance schedule == $beta_schedule_type" begin + β = rescale_zero_terminal_snr(beta_schedule_type(T)) + α = 1 .- β + α̅ = cumprod(α) + + SNR = α̅ ./ (1 .- α̅) + + @test SNR[end] ≈ 0 + end + end + end + end