From 8b098c4c0933529f890a5c41aa12b516cdb9d350 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laure=CE=B7t?= Date: Tue, 8 Aug 2023 20:40:57 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=A7=AA=20(Schedulers)=20add=20some=20test?= =?UTF-8?q?s=20to=20check=20`step`=20and=20`add=5Fnoise`=20correctness?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Project.toml | 3 ++- test/Schedulers.jl | 58 ++++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 61 insertions(+), 1 deletion(-) create mode 100644 test/Schedulers.jl diff --git a/Project.toml b/Project.toml index acfbe43..6590e19 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,8 @@ julia = "1" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Aqua"] +test = ["Test", "Aqua", "Statistics"] diff --git a/test/Schedulers.jl b/test/Schedulers.jl new file mode 100644 index 0000000..a59c00a --- /dev/null +++ b/test/Schedulers.jl @@ -0,0 +1,58 @@ +import Diffusers: step, add_noise, DDPM, cosine_beta_schedule, rescale_zero_terminal_snr +using Statistics +using Test + +@testset "Schedulers tests" begin + @testset "check `step` correctness" begin + T = 10 + batch_size = 8 + size = 128 + + # create a DDPM with a cosine beta schedule + ddpm = Diffusers.DDPM( + Vector{Float32}, + Diffusers.cosine_beta_schedule(T), + ) + + # create some dummy data + x₀ = ones(Float32, size, size, batch_size) + ϵ = randn(Float32, size, size, batch_size) + + for t in 1:T + t = ones(UInt32, batch_size) .* t + # corrupt x₀ with noise + xₜ = Diffusers.add_noise(ddpm, x₀, ϵ, t) + # suppose a model predicted ϵ perfectly + _, x̂₀ = Diffusers.step(ddpm, xₜ, ϵ, t) + # test that we recover x₀ + @test x̂₀ ≈ x₀ + end + end + + @testset "check `add_noise` terminal SNR" begin + T = 10 + batch_size = 1 + size = 1000 + + # create a DDPM with a terminal SNR cosine beta schedule + ddpm = Diffusers.DDPM( + Vector{Float32}, + Diffusers.rescale_zero_terminal_snr( + Diffusers.cosine_beta_schedule(T), + ), + ) + + # create some dummy data + x₀ = ones(Float32, size, size, batch_size) + ϵ = randn(Float32, size, size, batch_size) + + t = ones(UInt32, batch_size) .* T + # corrupt x₀ with noise + xₜ = Diffusers.add_noise(ddpm, x₀, ϵ, t) + + @test std(xₜ) ≈ 1.0 atol=1f-3 + @test mean(xₜ) ≈ 0.0 atol=1f-3 + end + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 3a7ab0d..6a746e4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,4 +3,5 @@ using Aqua Aqua.test_all(Diffusers) +include("Schedulers.jl") include("BetaSchedules.jl")