diff --git a/test/BetaSchedules.jl b/test/BetaSchedules.jl index 82dc8d1..e87de92 100644 --- a/test/BetaSchedules.jl +++ b/test/BetaSchedules.jl @@ -2,7 +2,7 @@ using Diffusers.BetaSchedules using Test @testset "Variance schedules tests" begin - @testset "β increases monotonically" begin + @testset "SNR decreases monotonically" begin T = 1000 β_linear = linear_beta_schedule(T) @@ -10,9 +10,20 @@ using Test β_cosine = cosine_beta_schedule(T) β_sigmoid = sigmoid_beta_schedule(T) - @test all(diff(β_linear) .>= 0) - @test all(diff(β_scaled_linear) .>= 0) - @test all(diff(β_cosine) .>= 0) - @test all(diff(β_sigmoid) .>= 0) + α̅_linear = cumprod(1 .- β_linear) + α̅_scaled_linear = cumprod(1 .- β_scaled_linear) + α̅_cosine = cumprod(1 .- β_cosine) + α̅_sigmoid = cumprod(1 .- β_sigmoid) + + # arxiv:2208.11970 (eq. 109) + SNR_linear = α̅_linear ./ (1 .- α̅_linear) + SNR_scaled_linear = α̅_scaled_linear ./ (1 .- α̅_scaled_linear) + SNR_cosine = α̅_cosine ./ (1 .- α̅_cosine) + SNR_sigmoid = α̅_sigmoid ./ (1 .- α̅_sigmoid) + + @test all(diff(SNR_linear) .<= 0) + @test all(diff(SNR_scaled_linear) .<= 0) + @test all(diff(SNR_cosine) .<= 0) + @test all(diff(SNR_sigmoid) .<= 0) end end