From 9896a1665e68d9dadbf0a929f312dc7da1698ef1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laure=CE=B7t?= Date: Sun, 3 Sep 2023 20:54:35 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=A7=AA=20(beta=5Fschedules)=20add=20ZeroS?= =?UTF-8?q?NR=20rescaling=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/BetaSchedules.jl | 61 +++++++++++++++++++++++++++---------------- 1 file changed, 39 insertions(+), 22 deletions(-) diff --git a/test/BetaSchedules.jl b/test/BetaSchedules.jl index b90ad4c..9285eb4 100644 --- a/test/BetaSchedules.jl +++ b/test/BetaSchedules.jl @@ -2,32 +2,49 @@ using Diffusers.BetaSchedules using Test @testset "Variance schedules tests" begin + @testset "SNR decreases monotonically" begin T = 1000 - β_linear = linear_beta_schedule(T) - β_scaled_linear = scaled_linear_beta_schedule(T) - β_cosine = cosine_beta_schedule(T) - β_sigmoid = sigmoid_beta_schedule(T) - β_exponential = exponential_beta_schedule(T) + for beta_schedule_type in [ + linear_beta_schedule, + scaled_linear_beta_schedule, + cosine_beta_schedule, + sigmoid_beta_schedule, + exponential_beta_schedule, + ] + @testset "Variance schedule == $beta_schedule_type" begin + β = beta_schedule_type(T) + α = 1 .- β + α̅ = cumprod(α) - α̅_linear = cumprod(1 .- β_linear) - α̅_scaled_linear = cumprod(1 .- β_scaled_linear) - α̅_cosine = cumprod(1 .- β_cosine) - α̅_sigmoid = cumprod(1 .- β_sigmoid) - α̅_exponential = cumprod(1 .- β_exponential) + SNR = α̅ ./ (1 .- α̅) - # arxiv:2208.11970 Eq. 109 - SNR_linear = α̅_linear ./ (1 .- α̅_linear) - SNR_scaled_linear = α̅_scaled_linear ./ (1 .- α̅_scaled_linear) - SNR_cosine = α̅_cosine ./ (1 .- α̅_cosine) - SNR_sigmoid = α̅_sigmoid ./ (1 .- α̅_sigmoid) - SNR_exponential = α̅_exponential ./ (1 .- α̅_exponential) - - @test all(diff(SNR_linear) .<= 0) - @test all(diff(SNR_scaled_linear) .<= 0) - @test all(diff(SNR_cosine) .<= 0) - @test all(diff(SNR_sigmoid) .<= 0) - @test all(diff(SNR_exponential) .<= 0) + @test all(diff(SNR) .<= 0) + end + end end + + @testset "ZeroSNR rescaling" begin + T = 1000 + + for beta_schedule_type in [ + linear_beta_schedule, + scaled_linear_beta_schedule, + cosine_beta_schedule, + sigmoid_beta_schedule, + exponential_beta_schedule, + ] + @testset "Variance schedule == $beta_schedule_type" begin + β = rescale_zero_terminal_snr(beta_schedule_type(T)) + α = 1 .- β + α̅ = cumprod(α) + + SNR = α̅ ./ (1 .- α̅) + + @test SNR[end] ≈ 0 + end + end + end + end