Schedulers: DDIM sample/velocity prediction, variance helper, shared scheduler tests

* DDIM.jl `reverse`: keyword arguments are now separated from positionals by `;`,
  the debug `println` calls and the `clamp.(x̂₀, -3, 3)` step are removed, the
  variance comes from the new `get_variance`, and the direction term uses `ϵ̂`
  with a broadcast square (`σₜ .^ 2`).
* DDIM.jl `get_prediction`: returns the pair `(x̂₀, ϵ̂)` and handles the SAMPLE
  and VELOCITY prediction types in addition to EPSILON.
* DDIM.jl `get_variance`: new helper holding the σ²ₜ computation that used to be
  inlined in `reverse`.
* DDPM.jl `reverse`: keyword arguments are now separated by `;`.
* test/Schedulers.jl: the round-trip checks run for every scheduler type and for
  the EPSILON, SAMPLE and VELOCITY prediction types. `DDIM` is imported
  explicitly because the loop references it, and the VELOCITY check passes the
  "model prediction" `vᵧ` through the imported `reverse`, like its siblings.

NOTE(review): this patch was recovered from a whitespace-collapsed copy. Line
breaks were rebuilt to agree with the hunk header counts; indentation width
(4 spaces) and the position of one blank context line after `Δₜ = 1` are
assumptions. Apply with `git apply --ignore-whitespace` and confirm.

diff --git a/src/Schedulers/DDIM.jl b/src/Schedulers/DDIM.jl
index 32f258a..b2bdb3f 100644
--- a/src/Schedulers/DDIM.jl
+++ b/src/Schedulers/DDIM.jl
@@ -94,39 +94,23 @@ function reverse(
     xₜ::AbstractArray,
     ϵᵧ::AbstractArray,
     t::AbstractArray,
+    ;
     η::Real=0.0f0,
-    prediction_type::PredictionType=EPSILON,
+    prediction_type::PredictionType=EPSILON
 )
     Δₜ = 1
 
-    α̅ = _extract(scheduler.α̅, xₜ)
     α̅₋ₚ = ShiftedArray(scheduler.α̅, Δₜ, default=1.0f0)
-    α̅ₜ = _extract(α̅[t], xₜ)
     α̅ₜ₋ₚ = _extract(α̅₋ₚ[t], xₜ)
-    β̅ = _extract(scheduler.β̅, xₜ)
-    β̅₋ₚ = ShiftedArray(scheduler.β̅, Δₜ, default=0.0f0)
-    β̅ₜ = _extract(β̅[t], xₜ)
-    β̅ₜ₋ₚ = _extract(β̅₋ₚ[t], xₜ)
-
-    println("α̅ₜ = ", α̅ₜ)
-    println("α̅ₜ₋ₚ = ", α̅ₜ₋ₚ)
 
     # compute x₀ (approximation)
-    x̂₀ = get_prediction(scheduler, prediction_type, xₜ, ϵᵧ, t)
-
-    # clip between -1 and 1
-    x̂₀ = clamp.(x̂₀, -3, 3)
+    x̂₀, ϵ̂ = get_prediction(scheduler, prediction_type, xₜ, ϵᵧ, t)
 
     # compute σ (exact)
-    σ²ₜ = (β̅ₜ₋ₚ ./ β̅ₜ) .* (1 .- α̅ₜ ./ α̅ₜ₋ₚ)
-    println("β̅ₜ₋ₚ = ", β̅ₜ₋ₚ)
-    println("β̅ₜ = ", β̅ₜ)
-    println("α̅ₜ = ", α̅ₜ)
-    println("α̅ₜ₋ₚ = ", α̅ₜ₋ₚ)
-    println("σ²ₜ = ", σ²ₜ)
+    σ²ₜ = get_variance(scheduler, xₜ, t, Δₜ)
     σₜ = η * sqrt.(σ²ₜ)
 
     # compute direction
-    Δₓ = sqrt.(1 .- α̅ₜ₋ₚ .- σₜ^2) .* ϵᵧ
+    Δₓ = sqrt.(1 .- α̅ₜ₋ₚ .- σₜ .^ 2) .* ϵ̂
 
     # sample xₜ₋ₚ
     ϵ = randn(Float32, size(xₜ))
@@ -160,18 +144,37 @@ function get_prediction(
     ⎷β̅ₜ = _extract(scheduler.⎷β̅[t], xₜ)
 
     if prediction_type == EPSILON
-        # arxiv:2006.11239 Eq. 15
-        # arxiv:2208.11970 Eq. 115
         x̂₀ = (xₜ - ⎷β̅ₜ .* ϵᵧ) ./ ⎷α̅ₜ
-        # elseif prediction_type == SAMPLE
-        #   # arxiv:2208.11970 Eq. 99
-        #   x̂₀ = ϵᵧ
-        # elseif prediction_type == VELOCITY
-        #   # arxiv:2202.00512 Eq. 31
-        #   x̂₀ = ⎷α̅ₜ .* xₜ - ⎷β̅ₜ .* ϵᵧ
+        ϵ̂ = ϵᵧ
+    elseif prediction_type == SAMPLE
+        x̂₀ = ϵᵧ
+        ϵ̂ = (xₜ - ⎷α̅ₜ .* x̂₀) ./ ⎷β̅ₜ
+    elseif prediction_type == VELOCITY
+        x̂₀ = ⎷α̅ₜ .* xₜ .- ⎷β̅ₜ .* ϵᵧ
+        ϵ̂ = ⎷α̅ₜ .* ϵᵧ .+ ⎷β̅ₜ .* xₜ
     else
         throw("unimplemented prediction type")
     end
 
-    return x̂₀
+    return x̂₀, ϵ̂
+end
+
+function get_variance(
+    scheduler::DDIM,
+    xₜ::AbstractArray,
+    t::AbstractArray,
+    Δₜ::Integer,
+)
+    α̅ = _extract(scheduler.α̅, xₜ)
+    α̅₋ₚ = ShiftedArray(scheduler.α̅, Δₜ, default=1.0f0)
+    α̅ₜ = _extract(α̅[t], xₜ)
+    α̅ₜ₋ₚ = _extract(α̅₋ₚ[t], xₜ)
+    β̅ = _extract(scheduler.β̅, xₜ)
+    β̅₋ₚ = ShiftedArray(scheduler.β̅, Δₜ, default=0.0f0)
+    β̅ₜ = _extract(β̅[t], xₜ)
+    β̅ₜ₋ₚ = _extract(β̅₋ₚ[t], xₜ)
+
+    σ²ₜ = (β̅ₜ₋ₚ ./ β̅ₜ) .* (1 .- α̅ₜ ./ α̅ₜ₋ₚ)
+
+    return σ²ₜ
 end
diff --git a/src/Schedulers/DDPM.jl b/src/Schedulers/DDPM.jl
index 902462e..94d8c36 100644
--- a/src/Schedulers/DDPM.jl
+++ b/src/Schedulers/DDPM.jl
@@ -94,8 +94,9 @@ function reverse(
     xₜ::AbstractArray,
     ϵᵧ::AbstractArray,
     t::AbstractArray,
+    ;
     prediction_type::PredictionType=EPSILON,
-    variance_type::VarianceType=FIXED_SMALL,
+    variance_type::VarianceType=FIXED_SMALL
 )
     # retreive scheduler variables at timesteps t
     βₜ = _extract(scheduler.β[t], xₜ)
diff --git a/test/Schedulers.jl b/test/Schedulers.jl
index ea15b8f..c70bb00 100644
--- a/test/Schedulers.jl
+++ b/test/Schedulers.jl
@@ -1,4 +1,4 @@
-import Diffusers: reverse, forward, get_velocity, DDPM, cosine_beta_schedule, VELOCITY
+import Diffusers: reverse, forward, get_velocity, DDPM, DDIM, cosine_beta_schedule, VELOCITY, EPSILON, SAMPLE
 import Statistics: mean, std
 using Test
 
@@ -7,43 +7,64 @@ using Test
     @testset "check `reverse` correctness" begin
         T = 10
         batch_size = 8
-        size = 128
+        data_size = 128
 
-        # create a DDPM with a cosine beta schedule
-        ddpm = DDPM(cosine_beta_schedule(T))
+        # run the same checks for every scheduler type, each built on a cosine beta schedule
+        for scheduler_type in [DDPM, DDIM]
 
-        # create some dummy data
-        x₀ = ones(Float32, size, size, batch_size)
-        ϵ = randn(Float32, size, size, batch_size)
+            scheduler = scheduler_type(cosine_beta_schedule(T))
+            @testset "Scheduler == $scheduler_type" begin
+
+                # create some dummy data
+                x₀ = ones(Float32, data_size, data_size, batch_size)
+                ϵ = randn(Float32, data_size, data_size, batch_size)
+
+                @testset "PredictionType == EPSILON" begin
+                    for t in 1:T
+                        t = ones(UInt32, batch_size) .* t
+                        # get xₜ from forward diffusion process
+                        xₜ = forward(scheduler, x₀, ϵ, t)
+                        # suppose a model predicted ϵ perfectly
+                        ϵᵧ = ϵ
+                        # use reverse diffusion process to retrieve x̂₀
+                        _, x̂₀ = reverse(scheduler, xₜ, ϵᵧ, t; prediction_type=EPSILON)
+                        # test that we recover x₀
+                        @test x̂₀ ≈ x₀
+                    end
+                end
+
+                @testset "PredictionType == SAMPLE" begin
+                    for t in 1:T
+                        t = ones(UInt32, batch_size) .* t
+                        # get xₜ from forward diffusion process
+                        xₜ = forward(scheduler, x₀, ϵ, t)
+                        # suppose a model predicted x₀ perfectly
+                        ϵᵧ = x₀
+                        # use reverse diffusion process to retrieve x̂₀
+                        _, x̂₀ = reverse(scheduler, xₜ, ϵᵧ, t; prediction_type=SAMPLE)
+                        # test that we recover x₀
+                        @test x̂₀ ≈ x₀
+                    end
+                end
+
+                @testset "PredictionType == VELOCITY" begin
+                    for t in 1:T
+                        t = ones(UInt32, batch_size) .* t
+                        # get xₜ from forward diffusion process
+                        xₜ = forward(scheduler, x₀, ϵ, t)
+                        # compute vₜ to train model
+                        vₜ = get_velocity(scheduler, x₀, ϵ, t)
+                        # suppose a model predicted vₜ perfectly
+                        vᵧ = vₜ
+                        # use reverse diffusion process to retrieve x̂₀
+                        _, x̂₀ = reverse(scheduler, xₜ, vᵧ, t; prediction_type=VELOCITY)
+                        # test that we recover x₀
+                        @test x̂₀ ≈ x₀
+                    end
+                end
 
-        @testset "PredictionType == EPSILON" begin
-            for t in 1:T
-                t = ones(UInt32, batch_size) .* t
-                # get xₜ from forward diffusion process
-                xₜ = forward(ddpm, x₀, ϵ, t)
-                # suppose a model predicted ϵ perfectly
-                ϵᵧ = ϵ
-                # use reverse diffusion process to retreive x̂₀
-                _, x̂₀ = reverse(ddpm, xₜ, ϵᵧ, t)
-                # test that we recover x₀
-                @test x̂₀ ≈ x₀
             end
-        end
 
-        @testset "PredictionType == VELOCITY" begin
-            for t in 1:T
-                t = ones(UInt32, batch_size) .* t
-                # get xₜ from forward diffusion process
-                xₜ = forward(ddpm, x₀, ϵ, t)
-                # compute vₜ to train model
-                vₜ = get_velocity(ddpm, x₀, ϵ, t)
-                # suppose a model predicted vₜ perfectly
-                vᵧ = vₜ
-                # use reverse diffusion process to retreive x̂₀
-                _, x̂₀ = Diffusers.reverse(ddpm, xₜ, vₜ, t, VELOCITY)
-                # test that we recover x₀
-                @test x̂₀ ≈ x₀
-            end
         end
     end
 end