🧪 (DDIM) add reverse process correctness tests

This commit is contained in:
Laureηt 2023-09-03 20:30:20 +02:00
parent 02d261f516
commit 7427828eda
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
3 changed files with 88 additions and 63 deletions

View file

@ -94,39 +94,23 @@ function reverse(
xₜ::AbstractArray,
ϵᵧ::AbstractArray,
t::AbstractArray,
;
η::Real=0.0f0,
prediction_type::PredictionType=EPSILON,
prediction_type::PredictionType=EPSILON
)
Δₜ = 1
α̅ = _extract(scheduler.α̅, xₜ)
α̅₋ₚ = ShiftedArray(scheduler.α̅, Δₜ, default=1.0f0)
α̅ₜ = _extract(α̅[t], xₜ)
α̅ₜ₋ₚ = _extract(α̅₋ₚ[t], xₜ)
β̅ = _extract(scheduler.β̅, xₜ)
β̅₋ₚ = ShiftedArray(scheduler.β̅, Δₜ, default=0.0f0)
β̅ₜ = _extract(β̅[t], xₜ)
β̅ₜ₋ₚ = _extract(β̅₋ₚ[t], xₜ)
println("α̅ₜ = ", α̅ₜ)
println("α̅ₜ₋ₚ = ", α̅ₜ₋ₚ)
# compute x₀ (approximation)
x̂₀ = get_prediction(scheduler, prediction_type, xₜ, ϵᵧ, t)
# clip between -1 and 1
x̂₀ = clamp.(x̂₀, -3, 3)
x̂₀, ϵ̂ = get_prediction(scheduler, prediction_type, xₜ, ϵᵧ, t)
# compute σ (exact)
σ²ₜ = (β̅ₜ₋ₚ ./ β̅ₜ) .* (1 .- α̅ₜ ./ α̅ₜ₋ₚ)
println("β̅ₜ₋ₚ = ", β̅ₜ₋ₚ)
println("β̅ₜ = ", β̅ₜ)
println("α̅ₜ = ", α̅ₜ)
println("α̅ₜ₋ₚ = ", α̅ₜ₋ₚ)
println("σ²ₜ = ", σ²ₜ)
σ²ₜ = get_variance(scheduler, xₜ, t, Δₜ)
σₜ = η * sqrt.(σ²ₜ)
# compute direction
Δₓ = sqrt.(1 .- α̅ₜ₋ₚ .- σₜ^2) .* ϵᵧ
Δₓ = sqrt.(1 .- α̅ₜ₋ₚ .- σₜ .^ 2) .* ϵ̂
# sample xₜ₋ₚ
ϵ = randn(Float32, size(xₜ))
@ -160,18 +144,37 @@ function get_prediction(
⎷β̅ₜ = _extract(scheduler.⎷β̅[t], xₜ)
if prediction_type == EPSILON
# arxiv:2006.11239 Eq. 15
# arxiv:2208.11970 Eq. 115
x̂₀ = (xₜ - ⎷β̅ₜ .* ϵᵧ) ./ ⎷α̅ₜ
# elseif prediction_type == SAMPLE
# # arxiv:2208.11970 Eq. 99
# x̂₀ = ϵᵧ
# elseif prediction_type == VELOCITY
# # arxiv:2202.00512 Eq. 31
# x̂₀ = ⎷α̅ₜ .* xₜ - ⎷β̅ₜ .* ϵᵧ
ϵ̂ = ϵᵧ
elseif prediction_type == SAMPLE
x̂₀ = ϵᵧ
ϵ̂ = (xₜ - ⎷α̅ₜ .* x̂₀) ./ ⎷β̅ₜ
elseif prediction_type == VELOCITY
x̂₀ = ⎷α̅ₜ .* xₜ .- ⎷β̅ₜ .* ϵᵧ
ϵ̂ = ⎷α̅ₜ .* ϵᵧ .+ ⎷β̅ₜ .* xₜ
else
throw("unimplemented prediction type")
end
return x̂₀
return x̂₀, ϵ̂
end
function get_variance(
scheduler::DDIM,
xₜ::AbstractArray,
t::AbstractArray,
Δₜ::Integer,
)
α̅ = _extract(scheduler.α̅, xₜ)
α̅₋ₚ = ShiftedArray(scheduler.α̅, Δₜ, default=1.0f0)
α̅ₜ = _extract(α̅[t], xₜ)
α̅ₜ₋ₚ = _extract(α̅₋ₚ[t], xₜ)
β̅ = _extract(scheduler.β̅, xₜ)
β̅₋ₚ = ShiftedArray(scheduler.β̅, Δₜ, default=0.0f0)
β̅ₜ = _extract(β̅[t], xₜ)
β̅ₜ₋ₚ = _extract(β̅₋ₚ[t], xₜ)
σ²ₜ = (β̅ₜ₋ₚ ./ β̅ₜ) .* (1 .- α̅ₜ ./ α̅ₜ₋ₚ)
return σ²ₜ
end

View file

@ -94,8 +94,9 @@ function reverse(
xₜ::AbstractArray,
ϵᵧ::AbstractArray,
t::AbstractArray,
;
prediction_type::PredictionType=EPSILON,
variance_type::VarianceType=FIXED_SMALL,
variance_type::VarianceType=FIXED_SMALL
)
# retreive scheduler variables at timesteps t
βₜ = _extract(scheduler.β[t], xₜ)

View file

@ -1,4 +1,4 @@
import Diffusers: reverse, forward, get_velocity, DDPM, cosine_beta_schedule, VELOCITY
import Diffusers: reverse, forward, get_velocity, DDPM, cosine_beta_schedule, VELOCITY, EPSILON, SAMPLE
import Statistics: mean, std
using Test
@ -7,24 +7,41 @@ using Test
@testset "check `reverse` correctness" begin
T = 10
batch_size = 8
size = 128
data_size = 128
# create a DDPM with a cosine beta schedule
ddpm = DDPM(cosine_beta_schedule(T))
for scheduler_type in [DDPM, DDIM]
scheduler = scheduler_type(cosine_beta_schedule(T))
@testset "Scheduler == $scheduler_type" begin
# create some dummy data
x₀ = ones(Float32, size, size, batch_size)
ϵ = randn(Float32, size, size, batch_size)
x₀ = ones(Float32, data_size, data_size, batch_size)
ϵ = randn(Float32, data_size, data_size, batch_size)
@testset "PredictionType == EPSILON" begin
for t in 1:T
t = ones(UInt32, batch_size) .* t
# get xₜ from forward diffusion process
xₜ = forward(ddpm, x₀, ϵ, t)
xₜ = forward(scheduler, x₀, ϵ, t)
# suppose a model predicted ϵ perfectly
ϵᵧ = ϵ
# use reverse diffusion process to retreive x̂₀
_, x̂₀ = reverse(ddpm, xₜ, ϵᵧ, t)
_, x̂₀ = reverse(scheduler, xₜ, ϵᵧ, t; prediction_type=EPSILON)
# test that we recover x₀
@test x̂₀ x₀
end
end
@testset "PredictionType == SAMPLE" begin
for t in 1:T
t = ones(UInt32, batch_size) .* t
# get xₜ from forward diffusion process
xₜ = forward(scheduler, x₀, ϵ, t)
# suppose a model predicted x₀ perfectly
ϵᵧ = x₀
# use reverse diffusion process to retreive x̂₀
_, x̂₀ = reverse(scheduler, xₜ, ϵᵧ, t; prediction_type=SAMPLE)
# test that we recover x₀
@test x̂₀ x₀
end
@ -34,13 +51,13 @@ using Test
for t in 1:T
t = ones(UInt32, batch_size) .* t
# get xₜ from forward diffusion process
xₜ = forward(ddpm, x₀, ϵ, t)
xₜ = forward(scheduler, x₀, ϵ, t)
# compute vₜ to train model
vₜ = get_velocity(ddpm, x₀, ϵ, t)
vₜ = get_velocity(scheduler, x₀, ϵ, t)
# suppose a model predicted vₜ perfectly
vᵧ = vₜ
# use reverse diffusion process to retreive x̂₀
_, x̂₀ = Diffusers.reverse(ddpm, xₜ, vₜ, t, VELOCITY)
_, x̂₀ = Diffusers.reverse(scheduler, xₜ, vₜ, t; prediction_type=VELOCITY)
# test that we recover x₀
@test x̂₀ x₀
end
@ -48,6 +65,10 @@ using Test
end
end
end
@testset "check `forward` terminal SNR" begin
T = 10
batch_size = 1