🧪 (DDIM) add reverse process correctness tests

This commit is contained in:
Laureηt 2023-09-03 20:30:20 +02:00
parent 02d261f516
commit 7427828eda
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
3 changed files with 88 additions and 63 deletions

View file

@ -94,39 +94,23 @@ function reverse(
xₜ::AbstractArray, xₜ::AbstractArray,
ϵᵧ::AbstractArray, ϵᵧ::AbstractArray,
t::AbstractArray, t::AbstractArray,
;
η::Real=0.0f0, η::Real=0.0f0,
prediction_type::PredictionType=EPSILON, prediction_type::PredictionType=EPSILON
) )
Δₜ = 1 Δₜ = 1
α̅ = _extract(scheduler.α̅, xₜ)
α̅₋ₚ = ShiftedArray(scheduler.α̅, Δₜ, default=1.0f0) α̅₋ₚ = ShiftedArray(scheduler.α̅, Δₜ, default=1.0f0)
α̅ₜ = _extract(α̅[t], xₜ)
α̅ₜ₋ₚ = _extract(α̅₋ₚ[t], xₜ) α̅ₜ₋ₚ = _extract(α̅₋ₚ[t], xₜ)
β̅ = _extract(scheduler.β̅, xₜ)
β̅₋ₚ = ShiftedArray(scheduler.β̅, Δₜ, default=0.0f0)
β̅ₜ = _extract(β̅[t], xₜ)
β̅ₜ₋ₚ = _extract(β̅₋ₚ[t], xₜ)
println("α̅ₜ = ", α̅ₜ)
println("α̅ₜ₋ₚ = ", α̅ₜ₋ₚ)
# compute x₀ (approximation) # compute x₀ (approximation)
x̂₀ = get_prediction(scheduler, prediction_type, xₜ, ϵᵧ, t) x̂₀, ϵ̂ = get_prediction(scheduler, prediction_type, xₜ, ϵᵧ, t)
# clip between -3 and 3
x̂₀ = clamp.(x̂₀, -3, 3)
# compute σ (exact) # compute σ (exact)
σ²ₜ = (β̅ₜ₋ₚ ./ β̅ₜ) .* (1 .- α̅ₜ ./ α̅ₜ₋ₚ) σ²ₜ = get_variance(scheduler, xₜ, t, Δₜ)
println("β̅ₜ₋ₚ = ", β̅ₜ₋ₚ)
println("β̅ₜ = ", β̅ₜ)
println("α̅ₜ = ", α̅ₜ)
println("α̅ₜ₋ₚ = ", α̅ₜ₋ₚ)
println("σ²ₜ = ", σ²ₜ)
σₜ = η * sqrt.(σ²ₜ) σₜ = η * sqrt.(σ²ₜ)
# compute direction # compute direction
Δₓ = sqrt.(1 .- α̅ₜ₋ₚ .- σₜ^2) .* ϵᵧ Δₓ = sqrt.(1 .- α̅ₜ₋ₚ .- σₜ .^ 2) .* ϵ̂
# sample xₜ₋ₚ # sample xₜ₋ₚ
ϵ = randn(Float32, size(xₜ)) ϵ = randn(Float32, size(xₜ))
@ -160,18 +144,37 @@ function get_prediction(
⎷β̅ₜ = _extract(scheduler.⎷β̅[t], xₜ) ⎷β̅ₜ = _extract(scheduler.⎷β̅[t], xₜ)
if prediction_type == EPSILON if prediction_type == EPSILON
# arxiv:2006.11239 Eq. 15
# arxiv:2208.11970 Eq. 115
x̂₀ = (xₜ - ⎷β̅ₜ .* ϵᵧ) ./ ⎷α̅ₜ x̂₀ = (xₜ - ⎷β̅ₜ .* ϵᵧ) ./ ⎷α̅ₜ
# elseif prediction_type == SAMPLE ϵ̂ = ϵᵧ
# # arxiv:2208.11970 Eq. 99 elseif prediction_type == SAMPLE
# x̂₀ = ϵᵧ x̂₀ = ϵᵧ
# elseif prediction_type == VELOCITY ϵ̂ = (xₜ - ⎷α̅ₜ .* x̂₀) ./ ⎷β̅ₜ
# # arxiv:2202.00512 Eq. 31 elseif prediction_type == VELOCITY
# x̂₀ = ⎷α̅ₜ .* xₜ - ⎷β̅ₜ .* ϵᵧ x̂₀ = ⎷α̅ₜ .* xₜ .- ⎷β̅ₜ .* ϵᵧ
ϵ̂ = ⎷α̅ₜ .* ϵᵧ .+ ⎷β̅ₜ .* xₜ
else else
throw("unimplemented prediction type") throw("unimplemented prediction type")
end end
return x̂₀ return x̂₀, ϵ̂
end
function get_variance(
scheduler::DDIM,
xₜ::AbstractArray,
t::AbstractArray,
Δₜ::Integer,
)
α̅ = _extract(scheduler.α̅, xₜ)
α̅₋ₚ = ShiftedArray(scheduler.α̅, Δₜ, default=1.0f0)
α̅ₜ = _extract(α̅[t], xₜ)
α̅ₜ₋ₚ = _extract(α̅₋ₚ[t], xₜ)
β̅ = _extract(scheduler.β̅, xₜ)
β̅₋ₚ = ShiftedArray(scheduler.β̅, Δₜ, default=0.0f0)
β̅ₜ = _extract(β̅[t], xₜ)
β̅ₜ₋ₚ = _extract(β̅₋ₚ[t], xₜ)
σ²ₜ = (β̅ₜ₋ₚ ./ β̅ₜ) .* (1 .- α̅ₜ ./ α̅ₜ₋ₚ)
return σ²ₜ
end end

View file

@ -94,8 +94,9 @@ function reverse(
xₜ::AbstractArray, xₜ::AbstractArray,
ϵᵧ::AbstractArray, ϵᵧ::AbstractArray,
t::AbstractArray, t::AbstractArray,
;
prediction_type::PredictionType=EPSILON, prediction_type::PredictionType=EPSILON,
variance_type::VarianceType=FIXED_SMALL, variance_type::VarianceType=FIXED_SMALL
) )
# retrieve scheduler variables at timesteps t # retrieve scheduler variables at timesteps t
βₜ = _extract(scheduler.β[t], xₜ) βₜ = _extract(scheduler.β[t], xₜ)

View file

@ -1,4 +1,4 @@
import Diffusers: reverse, forward, get_velocity, DDPM, cosine_beta_schedule, VELOCITY import Diffusers: reverse, forward, get_velocity, DDPM, cosine_beta_schedule, VELOCITY, EPSILON, SAMPLE
import Statistics: mean, std import Statistics: mean, std
using Test using Test
@ -7,43 +7,64 @@ using Test
@testset "check `reverse` correctness" begin @testset "check `reverse` correctness" begin
T = 10 T = 10
batch_size = 8 batch_size = 8
size = 128 data_size = 128
# create a DDPM with a cosine beta schedule # create a DDPM with a cosine beta schedule
ddpm = DDPM(cosine_beta_schedule(T)) for scheduler_type in [DDPM, DDIM]
# create some dummy data scheduler = scheduler_type(cosine_beta_schedule(T))
x₀ = ones(Float32, size, size, batch_size) @testset "Scheduler == $scheduler_type" begin
ϵ = randn(Float32, size, size, batch_size)
# create some dummy data
x₀ = ones(Float32, data_size, data_size, batch_size)
ϵ = randn(Float32, data_size, data_size, batch_size)
@testset "PredictionType == EPSILON" begin
for t in 1:T
t = ones(UInt32, batch_size) .* t
# get xₜ from forward diffusion process
xₜ = forward(scheduler, x₀, ϵ, t)
# suppose a model predicted ϵ perfectly
ϵᵧ = ϵ
# use reverse diffusion process to retrieve x̂₀
_, x̂₀ = reverse(scheduler, xₜ, ϵᵧ, t; prediction_type=EPSILON)
# test that we recover x₀
@test x̂₀ ≈ x₀
end
end
@testset "PredictionType == SAMPLE" begin
for t in 1:T
t = ones(UInt32, batch_size) .* t
# get xₜ from forward diffusion process
xₜ = forward(scheduler, x₀, ϵ, t)
# suppose a model predicted x₀ perfectly
ϵᵧ = x₀
# use reverse diffusion process to retrieve x̂₀
_, x̂₀ = reverse(scheduler, xₜ, ϵᵧ, t; prediction_type=SAMPLE)
# test that we recover x₀
@test x̂₀ ≈ x₀
end
end
@testset "PredictionType == VELOCITY" begin
for t in 1:T
t = ones(UInt32, batch_size) .* t
# get xₜ from forward diffusion process
xₜ = forward(scheduler, x₀, ϵ, t)
# compute vₜ to train model
vₜ = get_velocity(scheduler, x₀, ϵ, t)
# suppose a model predicted vₜ perfectly
vᵧ = vₜ
# use reverse diffusion process to retrieve x̂₀
_, x̂₀ = Diffusers.reverse(scheduler, xₜ, vₜ, t; prediction_type=VELOCITY)
# test that we recover x₀
@test x̂₀ ≈ x₀
end
end
@testset "PredictionType == EPSILON" begin
for t in 1:T
t = ones(UInt32, batch_size) .* t
# get xₜ from forward diffusion process
xₜ = forward(ddpm, x₀, ϵ, t)
# suppose a model predicted ϵ perfectly
ϵᵧ = ϵ
# use reverse diffusion process to retrieve x̂₀
_, x̂₀ = reverse(ddpm, xₜ, ϵᵧ, t)
# test that we recover x₀
@test x̂₀ ≈ x₀
end end
end
@testset "PredictionType == VELOCITY" begin
for t in 1:T
t = ones(UInt32, batch_size) .* t
# get xₜ from forward diffusion process
xₜ = forward(ddpm, x₀, ϵ, t)
# compute vₜ to train model
vₜ = get_velocity(ddpm, x₀, ϵ, t)
# suppose a model predicted vₜ perfectly
vᵧ = vₜ
# use reverse diffusion process to retrieve x̂₀
_, x̂₀ = Diffusers.reverse(ddpm, xₜ, vₜ, t, VELOCITY)
# test that we recover x₀
@test x̂₀ ≈ x₀
end
end end
end end