2023-08-14 15:19:43 +00:00
|
|
|
import Diffusers: reverse, forward, get_velocity, DDPM, cosine_beta_schedule, VELOCITY
|
2023-08-14 14:29:39 +00:00
|
|
|
import Statistics: mean, std
|
2023-08-08 18:40:57 +00:00
|
|
|
using Test
|
|
|
|
|
|
|
|
# Scheduler tests for the DDPM:
#   * `reverse` must recover x₀ from `forward`'s output when the model's
#     prediction (ϵ or velocity) is perfect, at every timestep 1:T;
#   * `forward` at the final timestep T must produce data whose sample
#     std ≈ 1 and mean ≈ 0 (i.e. the signal is essentially gone).
@testset "Schedulers tests" begin
    @testset "check `reverse` correctness" begin
        T = 10
        batch_size = 8
        # named `img_size` (not `size`) so `Base.size` is not shadowed in this scope
        img_size = 128

        # create a DDPM with a cosine beta schedule
        ddpm = DDPM(cosine_beta_schedule(T))

        # create some dummy data
        x₀ = ones(Float32, img_size, img_size, batch_size)
        ϵ = randn(Float32, img_size, img_size, batch_size)

        @testset "PredictionType == EPSILON" begin
            for t in 1:T
                # one timestep per batch element, all equal to `t`
                timesteps = fill(t, batch_size)

                # get xₜ from forward diffusion process
                xₜ = forward(ddpm, x₀, ϵ, timesteps)

                # suppose a model predicted ϵ perfectly
                ϵᵧ = ϵ

                # use reverse diffusion process to retrieve x̂₀
                _, x̂₀ = reverse(ddpm, xₜ, ϵᵧ, timesteps)

                # test that we recover x₀
                @test x̂₀ ≈ x₀
            end
        end

        @testset "PredictionType == VELOCITY" begin
            for t in 1:T
                # one timestep per batch element, all equal to `t`
                timesteps = fill(t, batch_size)

                # get xₜ from forward diffusion process
                xₜ = forward(ddpm, x₀, ϵ, timesteps)

                # compute the velocity target vₜ a model would be trained on
                vₜ = get_velocity(ddpm, x₀, ϵ, timesteps)

                # suppose a model predicted vₜ perfectly
                vᵧ = vₜ

                # use reverse diffusion process to retrieve x̂₀ from the model's
                # (perfect) velocity prediction; the explicitly imported `reverse`
                # is used, exactly as in the EPSILON case
                _, x̂₀ = reverse(ddpm, xₜ, vᵧ, timesteps, VELOCITY)

                # test that we recover x₀
                @test x̂₀ ≈ x₀
            end
        end
    end

    @testset "check `forward` terminal SNR" begin
        T = 10
        batch_size = 1
        # large image so the sample statistics below are tight
        img_size = 2500

        # create a DDPM with a cosine beta schedule; this test expects the
        # schedule to leave (near-)zero signal at the final timestep T
        ddpm = DDPM(cosine_beta_schedule(T))

        # create some dummy data
        x₀ = ones(Float32, img_size, img_size, batch_size)
        ϵ = randn(Float32, img_size, img_size, batch_size)

        # every batch element at the final timestep T
        timesteps = fill(T, batch_size)

        # corrupt x₀ with noise
        xₜ = forward(ddpm, x₀, ϵ, timesteps)

        # with no signal left, xₜ should be statistically indistinguishable
        # from the standard-normal noise ϵ
        @test std(xₜ) ≈ 1.0 atol = 1.0f-3
        @test mean(xₜ) ≈ 0.0 atol = 1.0f-2
    end
end
|