mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-10-17 23:26:19 +00:00
🧪 (DDIM) add reverse process correctness tests
This commit is contained in:
parent
02d261f516
commit
7427828eda
|
@ -94,39 +94,23 @@ function reverse(
|
|||
xₜ::AbstractArray,
|
||||
ϵᵧ::AbstractArray,
|
||||
t::AbstractArray,
|
||||
;
|
||||
η::Real=0.0f0,
|
||||
prediction_type::PredictionType=EPSILON,
|
||||
prediction_type::PredictionType=EPSILON
|
||||
)
|
||||
Δₜ = 1
|
||||
α̅ = _extract(scheduler.α̅, xₜ)
|
||||
α̅₋ₚ = ShiftedArray(scheduler.α̅, Δₜ, default=1.0f0)
|
||||
α̅ₜ = _extract(α̅[t], xₜ)
|
||||
α̅ₜ₋ₚ = _extract(α̅₋ₚ[t], xₜ)
|
||||
β̅ = _extract(scheduler.β̅, xₜ)
|
||||
β̅₋ₚ = ShiftedArray(scheduler.β̅, Δₜ, default=0.0f0)
|
||||
β̅ₜ = _extract(β̅[t], xₜ)
|
||||
β̅ₜ₋ₚ = _extract(β̅₋ₚ[t], xₜ)
|
||||
|
||||
println("α̅ₜ = ", α̅ₜ)
|
||||
println("α̅ₜ₋ₚ = ", α̅ₜ₋ₚ)
|
||||
|
||||
# compute x₀ (approximation)
|
||||
x̂₀ = get_prediction(scheduler, prediction_type, xₜ, ϵᵧ, t)
|
||||
|
||||
# clip between -1 and 1
|
||||
x̂₀ = clamp.(x̂₀, -3, 3)
|
||||
x̂₀, ϵ̂ = get_prediction(scheduler, prediction_type, xₜ, ϵᵧ, t)
|
||||
|
||||
# compute σ (exact)
|
||||
σ²ₜ = (β̅ₜ₋ₚ ./ β̅ₜ) .* (1 .- α̅ₜ ./ α̅ₜ₋ₚ)
|
||||
println("β̅ₜ₋ₚ = ", β̅ₜ₋ₚ)
|
||||
println("β̅ₜ = ", β̅ₜ)
|
||||
println("α̅ₜ = ", α̅ₜ)
|
||||
println("α̅ₜ₋ₚ = ", α̅ₜ₋ₚ)
|
||||
println("σ²ₜ = ", σ²ₜ)
|
||||
σ²ₜ = get_variance(scheduler, xₜ, t, Δₜ)
|
||||
σₜ = η * sqrt.(σ²ₜ)
|
||||
|
||||
# compute direction
|
||||
Δₓ = sqrt.(1 .- α̅ₜ₋ₚ .- σₜ^2) .* ϵᵧ
|
||||
Δₓ = sqrt.(1 .- α̅ₜ₋ₚ .- σₜ .^ 2) .* ϵ̂
|
||||
|
||||
# sample xₜ₋ₚ
|
||||
ϵ = randn(Float32, size(xₜ))
|
||||
|
@ -160,18 +144,37 @@ function get_prediction(
|
|||
⎷β̅ₜ = _extract(scheduler.⎷β̅[t], xₜ)
|
||||
|
||||
if prediction_type == EPSILON
|
||||
# arxiv:2006.11239 Eq. 15
|
||||
# arxiv:2208.11970 Eq. 115
|
||||
x̂₀ = (xₜ - ⎷β̅ₜ .* ϵᵧ) ./ ⎷α̅ₜ
|
||||
# elseif prediction_type == SAMPLE
|
||||
# # arxiv:2208.11970 Eq. 99
|
||||
# x̂₀ = ϵᵧ
|
||||
# elseif prediction_type == VELOCITY
|
||||
# # arxiv:2202.00512 Eq. 31
|
||||
# x̂₀ = ⎷α̅ₜ .* xₜ - ⎷β̅ₜ .* ϵᵧ
|
||||
ϵ̂ = ϵᵧ
|
||||
elseif prediction_type == SAMPLE
|
||||
x̂₀ = ϵᵧ
|
||||
ϵ̂ = (xₜ - ⎷α̅ₜ .* x̂₀) ./ ⎷β̅ₜ
|
||||
elseif prediction_type == VELOCITY
|
||||
x̂₀ = ⎷α̅ₜ .* xₜ .- ⎷β̅ₜ .* ϵᵧ
|
||||
ϵ̂ = ⎷α̅ₜ .* ϵᵧ .+ ⎷β̅ₜ .* xₜ
|
||||
else
|
||||
throw("unimplemented prediction type")
|
||||
end
|
||||
|
||||
return x̂₀
|
||||
return x̂₀, ϵ̂
|
||||
end
|
||||
|
||||
"""
    get_variance(scheduler::DDIM, xₜ::AbstractArray, t::AbstractArray, Δₜ::Integer)

Return the variance `σ²ₜ` that the DDIM `reverse` step scales by `η`
(`σₜ = η * sqrt.(σ²ₜ)`).

The value is computed elementwise as

    σ²ₜ = (β̅ₜ₋ₚ / β̅ₜ) * (1 - α̅ₜ / α̅ₜ₋ₚ)

where the `ₜ` quantities are read from `scheduler.α̅` / `scheduler.β̅` at the
indices `t`, and the `ₜ₋ₚ` quantities are read at the same indices from copies
of those tables wrapped in `ShiftedArray(…, Δₜ)`.

# Arguments
- `scheduler`: provides the `α̅` and `β̅` tables.
- `xₜ`: forwarded to every `_extract` call.
  NOTE(review): presumably only used to shape the result for broadcasting — confirm.
- `t`: indices used to look up the schedule values.
- `Δₜ`: shift passed to `ShiftedArray`.

# Returns
The result of the broadcasted expression above.
"""
function get_variance(
    scheduler::DDIM,
    xₜ::AbstractArray,
    t::AbstractArray,
    Δₜ::Integer,
)
    # Schedule tables as returned by `_extract`.
    alpha_bar = _extract(scheduler.α̅, xₜ)
    beta_bar = _extract(scheduler.β̅, xₜ)

    # Same tables shifted by Δₜ. `default` is 1 for α̅ and 0 for β̅.
    # NOTE(review): presumably `default` fills positions shifted in from
    # outside the table (i.e. the step before the first one) — confirm.
    alpha_bar_shifted = ShiftedArray(scheduler.α̅, Δₜ, default=1.0f0)
    beta_bar_shifted = ShiftedArray(scheduler.β̅, Δₜ, default=0.0f0)

    # Values at the requested indices, current and shifted.
    alpha_now = _extract(alpha_bar[t], xₜ)
    alpha_prev = _extract(alpha_bar_shifted[t], xₜ)
    beta_now = _extract(beta_bar[t], xₜ)
    beta_prev = _extract(beta_bar_shifted[t], xₜ)

    return @. (beta_prev / beta_now) * (1 - alpha_now / alpha_prev)
end
|
||||
|
|
|
@ -94,8 +94,9 @@ function reverse(
|
|||
xₜ::AbstractArray,
|
||||
ϵᵧ::AbstractArray,
|
||||
t::AbstractArray,
|
||||
;
|
||||
prediction_type::PredictionType=EPSILON,
|
||||
variance_type::VarianceType=FIXED_SMALL,
|
||||
variance_type::VarianceType=FIXED_SMALL
|
||||
)
|
||||
# retrieve scheduler variables at timesteps t
|
||||
βₜ = _extract(scheduler.β[t], xₜ)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import Diffusers: reverse, forward, get_velocity, DDPM, cosine_beta_schedule, VELOCITY
|
||||
import Diffusers: reverse, forward, get_velocity, DDPM, cosine_beta_schedule, VELOCITY, EPSILON, SAMPLE
|
||||
import Statistics: mean, std
|
||||
using Test
|
||||
|
||||
|
@ -7,43 +7,64 @@ using Test
|
|||
@testset "check `reverse` correctness" begin
|
||||
T = 10
|
||||
batch_size = 8
|
||||
size = 128
|
||||
data_size = 128
|
||||
|
||||
# create a DDPM with a cosine beta schedule
|
||||
ddpm = DDPM(cosine_beta_schedule(T))
|
||||
for scheduler_type in [DDPM, DDIM]
|
||||
|
||||
# create some dummy data
|
||||
x₀ = ones(Float32, size, size, batch_size)
|
||||
ϵ = randn(Float32, size, size, batch_size)
|
||||
scheduler = scheduler_type(cosine_beta_schedule(T))
|
||||
@testset "Scheduler == $scheduler_type" begin
|
||||
|
||||
# create some dummy data
|
||||
x₀ = ones(Float32, data_size, data_size, batch_size)
|
||||
ϵ = randn(Float32, data_size, data_size, batch_size)
|
||||
|
||||
@testset "PredictionType == EPSILON" begin
|
||||
for t in 1:T
|
||||
t = ones(UInt32, batch_size) .* t
|
||||
# get xₜ from forward diffusion process
|
||||
xₜ = forward(scheduler, x₀, ϵ, t)
|
||||
# suppose a model predicted ϵ perfectly
|
||||
ϵᵧ = ϵ
|
||||
# use reverse diffusion process to retrieve x̂₀
|
||||
_, x̂₀ = reverse(scheduler, xₜ, ϵᵧ, t; prediction_type=EPSILON)
|
||||
# test that we recover x₀
|
||||
@test x̂₀ ≈ x₀
|
||||
end
|
||||
end
|
||||
|
||||
@testset "PredictionType == SAMPLE" begin
|
||||
for t in 1:T
|
||||
t = ones(UInt32, batch_size) .* t
|
||||
# get xₜ from forward diffusion process
|
||||
xₜ = forward(scheduler, x₀, ϵ, t)
|
||||
# suppose a model predicted x₀ perfectly
|
||||
ϵᵧ = x₀
|
||||
# use reverse diffusion process to retrieve x̂₀
|
||||
_, x̂₀ = reverse(scheduler, xₜ, ϵᵧ, t; prediction_type=SAMPLE)
|
||||
# test that we recover x₀
|
||||
@test x̂₀ ≈ x₀
|
||||
end
|
||||
end
|
||||
|
||||
@testset "PredictionType == VELOCITY" begin
|
||||
for t in 1:T
|
||||
t = ones(UInt32, batch_size) .* t
|
||||
# get xₜ from forward diffusion process
|
||||
xₜ = forward(scheduler, x₀, ϵ, t)
|
||||
# compute vₜ to train model
|
||||
vₜ = get_velocity(scheduler, x₀, ϵ, t)
|
||||
# suppose a model predicted vₜ perfectly
|
||||
vᵧ = vₜ
|
||||
# use reverse diffusion process to retrieve x̂₀
|
||||
_, x̂₀ = Diffusers.reverse(scheduler, xₜ, vₜ, t; prediction_type=VELOCITY)
|
||||
# test that we recover x₀
|
||||
@test x̂₀ ≈ x₀
|
||||
end
|
||||
end
|
||||
|
||||
@testset "PredictionType == EPSILON" begin
|
||||
for t in 1:T
|
||||
t = ones(UInt32, batch_size) .* t
|
||||
# get xₜ from forward diffusion process
|
||||
xₜ = forward(ddpm, x₀, ϵ, t)
|
||||
# suppose a model predicted ϵ perfectly
|
||||
ϵᵧ = ϵ
|
||||
# use reverse diffusion process to retrieve x̂₀
|
||||
_, x̂₀ = reverse(ddpm, xₜ, ϵᵧ, t)
|
||||
# test that we recover x₀
|
||||
@test x̂₀ ≈ x₀
|
||||
end
|
||||
end
|
||||
|
||||
@testset "PredictionType == VELOCITY" begin
|
||||
for t in 1:T
|
||||
t = ones(UInt32, batch_size) .* t
|
||||
# get xₜ from forward diffusion process
|
||||
xₜ = forward(ddpm, x₀, ϵ, t)
|
||||
# compute vₜ to train model
|
||||
vₜ = get_velocity(ddpm, x₀, ϵ, t)
|
||||
# suppose a model predicted vₜ perfectly
|
||||
vᵧ = vₜ
|
||||
# use reverse diffusion process to retrieve x̂₀
|
||||
_, x̂₀ = Diffusers.reverse(ddpm, xₜ, vₜ, t, VELOCITY)
|
||||
# test that we recover x₀
|
||||
@test x̂₀ ≈ x₀
|
||||
end
|
||||
end
|
||||
|
||||
end
|
||||
|
|
Loading…
Reference in a new issue