mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-10-18 07:36:20 +00:00
🧪 (DDIM) add reverse process correctness tests
This commit is contained in:
parent
02d261f516
commit
7427828eda
|
@ -94,39 +94,23 @@ function reverse(
|
||||||
xₜ::AbstractArray,
|
xₜ::AbstractArray,
|
||||||
ϵᵧ::AbstractArray,
|
ϵᵧ::AbstractArray,
|
||||||
t::AbstractArray,
|
t::AbstractArray,
|
||||||
|
;
|
||||||
η::Real=0.0f0,
|
η::Real=0.0f0,
|
||||||
prediction_type::PredictionType=EPSILON,
|
prediction_type::PredictionType=EPSILON
|
||||||
)
|
)
|
||||||
Δₜ = 1
|
Δₜ = 1
|
||||||
α̅ = _extract(scheduler.α̅, xₜ)
|
|
||||||
α̅₋ₚ = ShiftedArray(scheduler.α̅, Δₜ, default=1.0f0)
|
α̅₋ₚ = ShiftedArray(scheduler.α̅, Δₜ, default=1.0f0)
|
||||||
α̅ₜ = _extract(α̅[t], xₜ)
|
|
||||||
α̅ₜ₋ₚ = _extract(α̅₋ₚ[t], xₜ)
|
α̅ₜ₋ₚ = _extract(α̅₋ₚ[t], xₜ)
|
||||||
β̅ = _extract(scheduler.β̅, xₜ)
|
|
||||||
β̅₋ₚ = ShiftedArray(scheduler.β̅, Δₜ, default=0.0f0)
|
|
||||||
β̅ₜ = _extract(β̅[t], xₜ)
|
|
||||||
β̅ₜ₋ₚ = _extract(β̅₋ₚ[t], xₜ)
|
|
||||||
|
|
||||||
println("α̅ₜ = ", α̅ₜ)
|
|
||||||
println("α̅ₜ₋ₚ = ", α̅ₜ₋ₚ)
|
|
||||||
|
|
||||||
# compute x₀ (approximation)
|
# compute x₀ (approximation)
|
||||||
x̂₀ = get_prediction(scheduler, prediction_type, xₜ, ϵᵧ, t)
|
x̂₀, ϵ̂ = get_prediction(scheduler, prediction_type, xₜ, ϵᵧ, t)
|
||||||
|
|
||||||
# clip between -1 and 1
|
|
||||||
x̂₀ = clamp.(x̂₀, -3, 3)
|
|
||||||
|
|
||||||
# compute σ (exact)
|
# compute σ (exact)
|
||||||
σ²ₜ = (β̅ₜ₋ₚ ./ β̅ₜ) .* (1 .- α̅ₜ ./ α̅ₜ₋ₚ)
|
σ²ₜ = get_variance(scheduler, xₜ, t, Δₜ)
|
||||||
println("β̅ₜ₋ₚ = ", β̅ₜ₋ₚ)
|
|
||||||
println("β̅ₜ = ", β̅ₜ)
|
|
||||||
println("α̅ₜ = ", α̅ₜ)
|
|
||||||
println("α̅ₜ₋ₚ = ", α̅ₜ₋ₚ)
|
|
||||||
println("σ²ₜ = ", σ²ₜ)
|
|
||||||
σₜ = η * sqrt.(σ²ₜ)
|
σₜ = η * sqrt.(σ²ₜ)
|
||||||
|
|
||||||
# compute direction
|
# compute direction
|
||||||
Δₓ = sqrt.(1 .- α̅ₜ₋ₚ .- σₜ^2) .* ϵᵧ
|
Δₓ = sqrt.(1 .- α̅ₜ₋ₚ .- σₜ .^ 2) .* ϵ̂
|
||||||
|
|
||||||
# sample xₜ₋ₚ
|
# sample xₜ₋ₚ
|
||||||
ϵ = randn(Float32, size(xₜ))
|
ϵ = randn(Float32, size(xₜ))
|
||||||
|
@ -160,18 +144,37 @@ function get_prediction(
|
||||||
⎷β̅ₜ = _extract(scheduler.⎷β̅[t], xₜ)
|
⎷β̅ₜ = _extract(scheduler.⎷β̅[t], xₜ)
|
||||||
|
|
||||||
if prediction_type == EPSILON
|
if prediction_type == EPSILON
|
||||||
# arxiv:2006.11239 Eq. 15
|
|
||||||
# arxiv:2208.11970 Eq. 115
|
|
||||||
x̂₀ = (xₜ - ⎷β̅ₜ .* ϵᵧ) ./ ⎷α̅ₜ
|
x̂₀ = (xₜ - ⎷β̅ₜ .* ϵᵧ) ./ ⎷α̅ₜ
|
||||||
# elseif prediction_type == SAMPLE
|
ϵ̂ = ϵᵧ
|
||||||
# # arxiv:2208.11970 Eq. 99
|
elseif prediction_type == SAMPLE
|
||||||
# x̂₀ = ϵᵧ
|
x̂₀ = ϵᵧ
|
||||||
# elseif prediction_type == VELOCITY
|
ϵ̂ = (xₜ - ⎷α̅ₜ .* x̂₀) ./ ⎷β̅ₜ
|
||||||
# # arxiv:2202.00512 Eq. 31
|
elseif prediction_type == VELOCITY
|
||||||
# x̂₀ = ⎷α̅ₜ .* xₜ - ⎷β̅ₜ .* ϵᵧ
|
x̂₀ = ⎷α̅ₜ .* xₜ .- ⎷β̅ₜ .* ϵᵧ
|
||||||
|
ϵ̂ = ⎷α̅ₜ .* ϵᵧ .+ ⎷β̅ₜ .* xₜ
|
||||||
else
|
else
|
||||||
throw("unimplemented prediction type")
|
throw("unimplemented prediction type")
|
||||||
end
|
end
|
||||||
|
|
||||||
return x̂₀
|
return x̂₀, ϵ̂
|
||||||
|
end
|
||||||
|
|
||||||
|
function get_variance(
|
||||||
|
scheduler::DDIM,
|
||||||
|
xₜ::AbstractArray,
|
||||||
|
t::AbstractArray,
|
||||||
|
Δₜ::Integer,
|
||||||
|
)
|
||||||
|
α̅ = _extract(scheduler.α̅, xₜ)
|
||||||
|
α̅₋ₚ = ShiftedArray(scheduler.α̅, Δₜ, default=1.0f0)
|
||||||
|
α̅ₜ = _extract(α̅[t], xₜ)
|
||||||
|
α̅ₜ₋ₚ = _extract(α̅₋ₚ[t], xₜ)
|
||||||
|
β̅ = _extract(scheduler.β̅, xₜ)
|
||||||
|
β̅₋ₚ = ShiftedArray(scheduler.β̅, Δₜ, default=0.0f0)
|
||||||
|
β̅ₜ = _extract(β̅[t], xₜ)
|
||||||
|
β̅ₜ₋ₚ = _extract(β̅₋ₚ[t], xₜ)
|
||||||
|
|
||||||
|
σ²ₜ = (β̅ₜ₋ₚ ./ β̅ₜ) .* (1 .- α̅ₜ ./ α̅ₜ₋ₚ)
|
||||||
|
|
||||||
|
return σ²ₜ
|
||||||
end
|
end
|
||||||
|
|
|
@ -94,8 +94,9 @@ function reverse(
|
||||||
xₜ::AbstractArray,
|
xₜ::AbstractArray,
|
||||||
ϵᵧ::AbstractArray,
|
ϵᵧ::AbstractArray,
|
||||||
t::AbstractArray,
|
t::AbstractArray,
|
||||||
|
;
|
||||||
prediction_type::PredictionType=EPSILON,
|
prediction_type::PredictionType=EPSILON,
|
||||||
variance_type::VarianceType=FIXED_SMALL,
|
variance_type::VarianceType=FIXED_SMALL
|
||||||
)
|
)
|
||||||
# retreive scheduler variables at timesteps t
|
# retreive scheduler variables at timesteps t
|
||||||
βₜ = _extract(scheduler.β[t], xₜ)
|
βₜ = _extract(scheduler.β[t], xₜ)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import Diffusers: reverse, forward, get_velocity, DDPM, cosine_beta_schedule, VELOCITY
|
import Diffusers: reverse, forward, get_velocity, DDPM, cosine_beta_schedule, VELOCITY, EPSILON, SAMPLE
|
||||||
import Statistics: mean, std
|
import Statistics: mean, std
|
||||||
using Test
|
using Test
|
||||||
|
|
||||||
|
@ -7,24 +7,41 @@ using Test
|
||||||
@testset "check `reverse` correctness" begin
|
@testset "check `reverse` correctness" begin
|
||||||
T = 10
|
T = 10
|
||||||
batch_size = 8
|
batch_size = 8
|
||||||
size = 128
|
data_size = 128
|
||||||
|
|
||||||
# create a DDPM with a cosine beta schedule
|
# create a DDPM with a cosine beta schedule
|
||||||
ddpm = DDPM(cosine_beta_schedule(T))
|
for scheduler_type in [DDPM, DDIM]
|
||||||
|
|
||||||
|
scheduler = scheduler_type(cosine_beta_schedule(T))
|
||||||
|
@testset "Scheduler == $scheduler_type" begin
|
||||||
|
|
||||||
# create some dummy data
|
# create some dummy data
|
||||||
x₀ = ones(Float32, size, size, batch_size)
|
x₀ = ones(Float32, data_size, data_size, batch_size)
|
||||||
ϵ = randn(Float32, size, size, batch_size)
|
ϵ = randn(Float32, data_size, data_size, batch_size)
|
||||||
|
|
||||||
@testset "PredictionType == EPSILON" begin
|
@testset "PredictionType == EPSILON" begin
|
||||||
for t in 1:T
|
for t in 1:T
|
||||||
t = ones(UInt32, batch_size) .* t
|
t = ones(UInt32, batch_size) .* t
|
||||||
# get xₜ from forward diffusion process
|
# get xₜ from forward diffusion process
|
||||||
xₜ = forward(ddpm, x₀, ϵ, t)
|
xₜ = forward(scheduler, x₀, ϵ, t)
|
||||||
# suppose a model predicted ϵ perfectly
|
# suppose a model predicted ϵ perfectly
|
||||||
ϵᵧ = ϵ
|
ϵᵧ = ϵ
|
||||||
# use reverse diffusion process to retreive x̂₀
|
# use reverse diffusion process to retreive x̂₀
|
||||||
_, x̂₀ = reverse(ddpm, xₜ, ϵᵧ, t)
|
_, x̂₀ = reverse(scheduler, xₜ, ϵᵧ, t; prediction_type=EPSILON)
|
||||||
|
# test that we recover x₀
|
||||||
|
@test x̂₀ ≈ x₀
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "PredictionType == SAMPLE" begin
|
||||||
|
for t in 1:T
|
||||||
|
t = ones(UInt32, batch_size) .* t
|
||||||
|
# get xₜ from forward diffusion process
|
||||||
|
xₜ = forward(scheduler, x₀, ϵ, t)
|
||||||
|
# suppose a model predicted x₀ perfectly
|
||||||
|
ϵᵧ = x₀
|
||||||
|
# use reverse diffusion process to retreive x̂₀
|
||||||
|
_, x̂₀ = reverse(scheduler, xₜ, ϵᵧ, t; prediction_type=SAMPLE)
|
||||||
# test that we recover x₀
|
# test that we recover x₀
|
||||||
@test x̂₀ ≈ x₀
|
@test x̂₀ ≈ x₀
|
||||||
end
|
end
|
||||||
|
@ -34,13 +51,13 @@ using Test
|
||||||
for t in 1:T
|
for t in 1:T
|
||||||
t = ones(UInt32, batch_size) .* t
|
t = ones(UInt32, batch_size) .* t
|
||||||
# get xₜ from forward diffusion process
|
# get xₜ from forward diffusion process
|
||||||
xₜ = forward(ddpm, x₀, ϵ, t)
|
xₜ = forward(scheduler, x₀, ϵ, t)
|
||||||
# compute vₜ to train model
|
# compute vₜ to train model
|
||||||
vₜ = get_velocity(ddpm, x₀, ϵ, t)
|
vₜ = get_velocity(scheduler, x₀, ϵ, t)
|
||||||
# suppose a model predicted vₜ perfectly
|
# suppose a model predicted vₜ perfectly
|
||||||
vᵧ = vₜ
|
vᵧ = vₜ
|
||||||
# use reverse diffusion process to retreive x̂₀
|
# use reverse diffusion process to retreive x̂₀
|
||||||
_, x̂₀ = Diffusers.reverse(ddpm, xₜ, vₜ, t, VELOCITY)
|
_, x̂₀ = Diffusers.reverse(scheduler, xₜ, vₜ, t; prediction_type=VELOCITY)
|
||||||
# test that we recover x₀
|
# test that we recover x₀
|
||||||
@test x̂₀ ≈ x₀
|
@test x̂₀ ≈ x₀
|
||||||
end
|
end
|
||||||
|
@ -48,6 +65,10 @@ using Test
|
||||||
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
end
|
||||||
|
|
||||||
|
end
|
||||||
|
|
||||||
@testset "check `forward` terminal SNR" begin
|
@testset "check `forward` terminal SNR" begin
|
||||||
T = 10
|
T = 10
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
|
|
Loading…
Reference in a new issue