🧪 (Schedulers) add get_velocity tests

This commit is contained in:
Laureηt 2023-08-14 17:19:43 +02:00
parent b90b50de27
commit 184b9665ec
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI

View file

@ -1,8 +1,9 @@
import Diffusers: reverse, forward, DDPM, cosine_beta_schedule
import Diffusers: reverse, forward, get_velocity, DDPM, cosine_beta_schedule, VELOCITY
import Statistics: mean, std
using Test
@testset "Schedulers tests" begin
@testset "check `reverse` correctness" begin
T = 10
batch_size = 8
@ -15,15 +16,36 @@ using Test
x₀ = ones(Float32, size, size, batch_size)
ϵ = randn(Float32, size, size, batch_size)
for t in 1:T
t = ones(UInt32, batch_size) .* t
# corrupt x₀ with noise
xₜ = forward(ddpm, x₀, ϵ, t)
# suppose a model predicted ϵ perfectly
_, x̂₀ = reverse(ddpm, xₜ, ϵ, t)
# test that we recover x₀
@test x̂₀ x₀
@testset "PredictionType == EPSILON" begin
for t in 1:T
t = ones(UInt32, batch_size) .* t
# get xₜ from forward diffusion process
xₜ = forward(ddpm, x₀, ϵ, t)
# suppose a model predicted ϵ perfectly
ϵᵧ = ϵ
# use reverse diffusion process to retreive x̂₀
_, x̂₀ = reverse(ddpm, xₜ, ϵᵧ, t)
# test that we recover x₀
@test x̂₀ x₀
end
end
@testset "PredictionType == VELOCITY" begin
for t in 1:T
t = ones(UInt32, batch_size) .* t
# get xₜ from forward diffusion process
xₜ = forward(ddpm, x₀, ϵ, t)
# compute vₜ to train model
vₜ = get_velocity(ddpm, x₀, ϵ, t)
# suppose a model predicted vₜ perfectly
vᵧ = vₜ
# use reverse diffusion process to retreive x̂₀
_, x̂₀ = Diffusers.reverse(ddpm, xₜ, vₜ, t, VELOCITY)
# test that we recover x₀
@test x̂₀ x₀
end
end
end
@testset "check `forward` terminal SNR" begin
@ -45,4 +67,5 @@ using Test
@test std(xₜ) 1.0 atol = 1.0f-3
@test mean(xₜ) 0.0 atol = 1.0f-2
end
end