🧪 (Schedulers) add some tests to check step and add_noise correctness

This commit is contained in:
Laureηt 2023-08-08 20:40:57 +02:00
parent 77961913ba
commit 8b098c4c09
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
3 changed files with 61 additions and 1 deletions

View file

@ -13,7 +13,8 @@ julia = "1"
[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[targets]
test = ["Test", "Aqua"]
test = ["Test", "Aqua", "Statistics"]

58
test/Schedulers.jl Normal file
View file

@ -0,0 +1,58 @@
import Diffusers: step, add_noise, DDPM, cosine_beta_schedule, rescale_zero_terminal_snr
using Statistics
using Test
@testset "Schedulers tests" begin
@testset "check `step` correctness" begin
T = 10
batch_size = 8
size = 128
# create a DDPM with a cosine beta schedule
ddpm = Diffusers.DDPM(
Vector{Float32},
Diffusers.cosine_beta_schedule(T),
)
# create some dummy data
x₀ = ones(Float32, size, size, batch_size)
ϵ = randn(Float32, size, size, batch_size)
for t in 1:T
t = ones(UInt32, batch_size) .* t
# corrupt x₀ with noise
xₜ = Diffusers.add_noise(ddpm, x₀, ϵ, t)
# suppose a model predicted ϵ perfectly
_, x̂₀ = Diffusers.step(ddpm, xₜ, ϵ, t)
# test that we recover x₀
@test x̂₀ x₀
end
end
@testset "check `add_noise` terminal SNR" begin
T = 10
batch_size = 1
size = 1000
# create a DDPM with a terminal SNR cosine beta schedule
ddpm = Diffusers.DDPM(
Vector{Float32},
Diffusers.rescale_zero_terminal_snr(
Diffusers.cosine_beta_schedule(T),
),
)
# create some dummy data
x₀ = ones(Float32, size, size, batch_size)
ϵ = randn(Float32, size, size, batch_size)
t = ones(UInt32, batch_size) .* T
# corrupt x₀ with noise
xₜ = Diffusers.add_noise(ddpm, x₀, ϵ, t)
@test std(xₜ) 1.0 atol=1f-3
@test mean(xₜ) 0.0 atol=1f-3
end
end
end
end

View file

@ -3,4 +3,5 @@ using Aqua
Aqua.test_all(Diffusers)
include("Schedulers.jl")
include("BetaSchedules.jl")