From 5ec22c47b2b5221fac401a93d1404b3e094a7562 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laure=CE=B7t?= Date: Mon, 14 Aug 2023 16:32:07 +0200 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20(Schedulers)=20add=20varia?= =?UTF-8?q?nce=20and=20prediction=20types?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Project.toml | 2 + src/Schedulers/Abstract.jl | 25 ++++++ src/Schedulers/DDPM.jl | 153 +++++++++++++++++++++++++++---------- 3 files changed, 138 insertions(+), 42 deletions(-) diff --git a/Project.toml b/Project.toml index 6590e19..7008c29 100644 --- a/Project.toml +++ b/Project.toml @@ -5,9 +5,11 @@ version = "0.1.0" [deps] NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +ShiftedArrays = "1277b4bf-5013-50f5-be3d-901d8477a67a" [compat] NNlib = "0.9" +ShiftedArrays = "2.0" julia = "1" [extras] diff --git a/src/Schedulers/Abstract.jl b/src/Schedulers/Abstract.jl index a419fec..fe4f51e 100644 --- a/src/Schedulers/Abstract.jl +++ b/src/Schedulers/Abstract.jl @@ -4,6 +4,9 @@ Abstract type for schedulers. """ abstract type Scheduler end +@enum VarianceType FIXED_SMALL FIXED_SMALL_LOG FIXED_LARGE FIXED_LARGE_LOG LEARNED +@enum PredictionType EPSILON SAMPLE VELOCITY + """ Add noise to clean data using the forward diffusion process. @@ -42,3 +45,25 @@ function reverse( ϵᵧ::AbstractArray, t::AbstractArray, ) end + +""" +Compute the velocity of the diffusion process. + +## Input + * `scheduler::Scheduler`: scheduler to use + * `x₀::AbstractArray`: clean data to add noise to + * `ϵ::AbstractArray`: noise to add to clean data + * `t::AbstractArray`: timesteps used to weight the noise + +## Output + * `vₜ::AbstractArray`: velocity at the given timesteps + +## References + * [[2202.00512] Progressive Distillation for Fast Sampling of Diffusion Models](https://arxiv.org/abs/2202.00512) (Ann. D) +""" +function get_velocity( + scheduler::Scheduler, + x₀::AbstractArray, + ϵ::AbstractArray, + t::AbstractArray, +) end diff --git a/src/Schedulers/DDPM.jl b/src/Schedulers/DDPM.jl index 59dcd22..e78d40d 100644 --- a/src/Schedulers/DDPM.jl +++ b/src/Schedulers/DDPM.jl @@ -1,5 +1,18 @@ include("Abstract.jl") +using ShiftedArrays + +function _extract( + target::AbstractArray, + reference::AbstractArray, +) + new_size = tuple( + fill(1, ndims(reference) - 1)..., + size(reference, ndims(reference)) + ) + return reshape(target, new_size) +end + """ Denoising Diffusion Probabilistic Models (DDPM) scheduler. @@ -9,8 +22,8 @@ Denoising Diffusion Probabilistic Models (DDPM) scheduler. struct DDPM{V<:AbstractVector} <: Scheduler T::Integer # length of markov chain - β::V # beta variance schedule α::V # 1 - beta + β::V # beta variance schedule ⎷α::V # square root of α ⎷β::V # square root of β @@ -39,7 +52,7 @@ function DDPM(β::AbstractVector) α̅ = cumprod(α) β̅ = 1 .- α̅ - α̅₋₁ = [1, (α̅[1:end-1])...] + α̅₋₁ = ShiftedArray(α̅, 1, default=1) β̅₋₁ = 1 .- α̅₋₁ ⎷α̅ = sqrt.(α̅) @@ -50,18 +63,12 @@ function DDPM(β::AbstractVector) DDPM{typeof(β)}( T, - β, - α, - ⎷α, - ⎷β, - α̅, - β̅, - α̅₋₁, - β̅₋₁, - ⎷α̅, - ⎷β̅, - ⎷α̅₋₁, - ⎷β̅₋₁, + α, β, + ⎷α, ⎷β, + α̅, β̅, + α̅₋₁, β̅₋₁, + ⎷α̅, ⎷β̅, + ⎷α̅₋₁, ⎷β̅₋₁, ) end @@ -72,12 +79,8 @@ function forward( t::AbstractArray, ) # retreive scheduler variables at timesteps t - reshape_size = tuple( - fill(1, ndims(x₀) - 1)..., - size(t, 1) - ) - ⎷α̅ₜ = reshape(scheduler.⎷α̅[t], reshape_size) - ⎷β̅ₜ = reshape(scheduler.⎷β̅[t], reshape_size) + ⎷α̅ₜ = _extract(scheduler.⎷α̅[t], x₀) + ⎷β̅ₜ = _extract(scheduler.⎷β̅[t], x₀) # noisify clean data # arxiv:2006.11239 Eq. 4 @@ -91,37 +94,103 @@ function reverse( xₜ::AbstractArray, ϵᵧ::AbstractArray, t::AbstractArray, + prediction_type::PredictionType=EPSILON, + variance_type::VarianceType=FIXED_SMALL, ) # retreive scheduler variables at timesteps t - reshape_size = tuple( - fill(1, ndims(xₜ) - 1)..., - size(t, 1) - ) - βₜ = reshape(scheduler.β[t], reshape_size) - β̅ₜ = reshape(scheduler.β̅[t], reshape_size) - β̅ₜ₋₁ = reshape(scheduler.β̅₋₁[t], reshape_size) - ⎷αₜ = reshape(scheduler.⎷α[t], reshape_size) - ⎷α̅ₜ = reshape(scheduler.⎷α̅[t], reshape_size) - ⎷α̅ₜ₋₁ = reshape(scheduler.⎷α̅₋₁[t], reshape_size) - ⎷β̅ₜ = reshape(scheduler.⎷β̅[t], reshape_size) + βₜ = _extract(scheduler.β[t], xₜ) + β̅ₜ = _extract(scheduler.β̅[t], xₜ) + β̅ₜ₋₁ = _extract(scheduler.β̅₋₁[t], xₜ) + ⎷αₜ = _extract(scheduler.⎷α[t], xₜ) + ⎷α̅ₜ₋₁ = _extract(scheduler.⎷α̅₋₁[t], xₜ) - # compute predicted previous sample x̂₀ - # arxiv:2006.11239 Eq. 15 - # arxiv:2208.11970 Eq. 115 - x̂₀ = (xₜ - ⎷β̅ₜ .* ϵᵧ) ./ ⎷α̅ₜ + # compute x₀ (approximation) + x̂₀ = get_prediction(scheduler, prediction_type, xₜ, ϵᵧ, t) - # compute predicted previous sample μ̃ₜ + # compute μₜ (approximation) # arxiv:2006.11239 Eq. 7 # arxiv:2208.11970 Eq. 84 λ₀ = ⎷α̅ₜ₋₁ .* βₜ ./ β̅ₜ - λₜ = ⎷αₜ .* β̅ₜ₋₁ ./ β̅ₜ # TODO: this could be stored in the scheduler + λₜ = ⎷αₜ .* β̅ₜ₋₁ ./ β̅ₜ μ̃ₜ = λ₀ .* x̂₀ + λₜ .* xₜ - # sample predicted previous sample xₜ₋₁ - # arxiv:2006.11239 Eq. 6 - # arxiv:2208.11970 Eq. 70 - σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ # TODO: this could be stored in the scheduler - xₜ₋₁ = μ̃ₜ + σₜ .* randn(size(ϵᵧ)) + # compute σ² (exact) + σ²ₜ = get_variance(scheduler, variance_type, xₜ, t) + + # sample xₜ₋₁ using μₜ and σ² + σₜ = sqrt.(σ²ₜ) + ϵ = randn(size(ϵᵧ)) + xₜ₋₁ = μ̃ₜ + σₜ .* ϵ return xₜ₋₁, x̂₀ end + +function get_velocity( + scheduler::DDPM, + x₀::AbstractArray, + ϵ::AbstractArray, + t::AbstractArray, +) + ⎷α̅ₜ = _extract(scheduler.⎷α̅[t], x₀) + ⎷β̅ₜ = _extract(scheduler.⎷β̅[t], x₀) + + vₜ = ⎷α̅ₜ .* ϵ - ⎷β̅ₜ .* x₀ + + return vₜ +end + +function get_prediction( + scheduler::DDPM, + prediction_type::PredictionType, + xₜ::AbstractArray, + ϵᵧ::AbstractArray, + t::AbstractArray, +) + ⎷α̅ₜ = _extract(scheduler.⎷α̅[t], xₜ) + ⎷β̅ₜ = _extract(scheduler.⎷β̅[t], xₜ) + + if prediction_type == EPSILON + # arxiv:2006.11239 Eq. 15 + # arxiv:2208.11970 Eq. 115 + x̂₀ = (xₜ - ⎷β̅ₜ .* ϵᵧ) ./ ⎷α̅ₜ + elseif prediction_type == SAMPLE + # arxiv:2208.11970 Eq. 99 + x̂₀ = ϵᵧ + elseif prediction_type == VELOCITY + # arxiv:2202.00512 Eq. 31 + x̂₀ = ⎷α̅ₜ .* xₜ - ⎷β̅ₜ .* ϵᵧ + else + throw("unimplemented prediction type") + end + + return x̂₀ +end + +function get_variance( + scheduler::DDPM, + variance_type::VarianceType, + xₜ::AbstractArray, + t::AbstractArray, +) + βₜ = _extract(scheduler.β[t], xₜ) + β̅ₜ = _extract(scheduler.β̅[t], xₜ) + β̅ₜ₋₁ = _extract(scheduler.β̅₋₁[t], xₜ) + + if variance_type == FIXED_SMALL + # arxiv:2006.11239 Eq. 6 + # arxiv:2208.11970 Eq. 70 + σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ + elseif variance_type == FIXED_SMALL_LOG + σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ + σₜ = log.(σₜ) + elseif variance_type == FIXED_LARGE + σₜ = βₜ + elseif variance_type == FIXED_LARGE_LOG + σₜ = βₜ + σₜ = log.(σₜ) + else + throw("unimplemented variance type") + end + + return σₜ +end