♻️ (Schedulers) add variance and prediction types

This commit is contained in:
Laureηt 2023-08-14 16:32:07 +02:00
parent 1e5ceb5140
commit 5ec22c47b2
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
3 changed files with 138 additions and 42 deletions

View file

@ -5,9 +5,11 @@ version = "0.1.0"
[deps]
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
ShiftedArrays = "1277b4bf-5013-50f5-be3d-901d8477a67a"
[compat]
NNlib = "0.9"
ShiftedArrays = "2.0"
julia = "1"
[extras]

View file

@ -4,6 +4,9 @@ Abstract type for schedulers.
"""
abstract type Scheduler end
# Selects how the reverse-process variance is computed by `get_variance`.
# In the DDPM method: `FIXED_SMALL` uses β̅ₜ₋₁ / β̅ₜ * βₜ, `FIXED_LARGE` uses βₜ, and
# the `*_LOG` variants return the elementwise `log` of the corresponding value.
# `LEARNED` has no branch in the DDPM `get_variance` and ends in its error fallback.
@enum VarianceType FIXED_SMALL FIXED_SMALL_LOG FIXED_LARGE FIXED_LARGE_LOG LEARNED
# Selects how `get_prediction` interprets the model output when estimating x₀:
# `EPSILON` treats it as noise, `SAMPLE` as the x₀ estimate itself, and
# `VELOCITY` as the velocity (see arxiv:2202.00512).
@enum PredictionType EPSILON SAMPLE VELOCITY
"""
Add noise to clean data using the forward diffusion process.
@ -42,3 +45,25 @@ function reverse(
ϵᵧ::AbstractArray,
t::AbstractArray,
) end
"""
Compute the velocity of the diffusion process.
## Input
* `scheduler::Scheduler`: scheduler to use
* `x₀::AbstractArray`: clean data to add noise to
* `ϵ::AbstractArray`: noise to add to clean data
* `t::AbstractArray`: timesteps used to weight the noise
## Output
* `vₜ::AbstractArray`: velocity at the given timesteps
## References
* [[2202.00512] Progressive Distillation for Fast Sampling of Diffusion Models](https://arxiv.org/abs/2202.00512) (Ann. D)
"""
function get_velocity(
scheduler::Scheduler,
x₀::AbstractArray,
ϵ::AbstractArray,
t::AbstractArray,
) end

View file

@ -1,5 +1,18 @@
include("Abstract.jl")
using ShiftedArrays
"""
    _extract(target::AbstractArray, reference::AbstractArray)

Reshape `target` into an array with as many dimensions as `reference`, where every
leading dimension has size 1 and the last dimension has the size of the last
dimension of `reference`.

`target` must therefore hold exactly `size(reference, ndims(reference))` elements,
otherwise `reshape` fails. The result is meant to be broadcast against `reference`.
"""
function _extract(target::AbstractArray, reference::AbstractArray)
    rank = ndims(reference)
    # one singleton dimension for each axis except the last one
    leading_ones = ntuple(_ -> 1, rank - 1)
    return reshape(target, (leading_ones..., size(reference, rank)))
end
"""
Denoising Diffusion Probabilistic Models (DDPM) scheduler.
@ -9,8 +22,8 @@ Denoising Diffusion Probabilistic Models (DDPM) scheduler.
struct DDPM{V<:AbstractVector} <: Scheduler
T::Integer # length of markov chain
β::V # beta variance schedule
α::V # 1 - beta
β::V # beta variance schedule
⎷α::V # square root of α
⎷β::V # square root of β
@ -39,7 +52,7 @@ function DDPM(β::AbstractVector)
α̅ = cumprod(α)
β̅ = 1 .- α̅
α̅₋₁ = [1, (α̅[1:end-1])...]
α̅₋₁ = ShiftedArray(α̅, 1, default=1)
β̅₋₁ = 1 .- α̅₋₁
⎷α̅ = sqrt.(α̅)
@ -50,18 +63,12 @@ function DDPM(β::AbstractVector)
DDPM{typeof(β)}(
T,
β,
α,
⎷α,
⎷β,
α̅,
β̅,
α̅₋₁,
β̅₋₁,
⎷α̅,
⎷β̅,
⎷α̅₋₁,
⎷β̅₋₁,
α, β,
⎷α, ⎷β,
α̅, β̅,
α̅₋₁, β̅₋₁,
⎷α̅, ⎷β̅,
⎷α̅₋₁, ⎷β̅₋₁,
)
end
@ -72,12 +79,8 @@ function forward(
t::AbstractArray,
)
# retrieve scheduler variables at timesteps t
reshape_size = tuple(
fill(1, ndims(x₀) - 1)...,
size(t, 1)
)
⎷α̅ₜ = reshape(scheduler.⎷α̅[t], reshape_size)
⎷β̅ₜ = reshape(scheduler.⎷β̅[t], reshape_size)
⎷α̅ₜ = _extract(scheduler.⎷α̅[t], x₀)
⎷β̅ₜ = _extract(scheduler.⎷β̅[t], x₀)
# noisify clean data
# arxiv:2006.11239 Eq. 4
@ -91,37 +94,103 @@ function reverse(
xₜ::AbstractArray,
ϵᵧ::AbstractArray,
t::AbstractArray,
prediction_type::PredictionType=EPSILON,
variance_type::VarianceType=FIXED_SMALL,
)
# retrieve scheduler variables at timesteps t
reshape_size = tuple(
fill(1, ndims(xₜ) - 1)...,
size(t, 1)
)
βₜ = reshape(scheduler.β[t], reshape_size)
β̅ₜ = reshape(scheduler.β̅[t], reshape_size)
β̅ₜ₋₁ = reshape(scheduler.β̅₋₁[t], reshape_size)
⎷αₜ = reshape(scheduler.⎷α[t], reshape_size)
⎷α̅ₜ = reshape(scheduler.⎷α̅[t], reshape_size)
⎷α̅ₜ₋₁ = reshape(scheduler.⎷α̅₋₁[t], reshape_size)
⎷β̅ₜ = reshape(scheduler.⎷β̅[t], reshape_size)
βₜ = _extract(scheduler.β[t], xₜ)
β̅ₜ = _extract(scheduler.β̅[t], xₜ)
β̅ₜ₋₁ = _extract(scheduler.β̅₋₁[t], xₜ)
⎷αₜ = _extract(scheduler.⎷α[t], xₜ)
⎷α̅ₜ₋₁ = _extract(scheduler.⎷α̅₋₁[t], xₜ)
# compute predicted previous sample x̂₀
# arxiv:2006.11239 Eq. 15
# arxiv:2208.11970 Eq. 115
x̂₀ = (xₜ - ⎷β̅ₜ .* ϵᵧ) ./ ⎷α̅ₜ
# compute x₀ (approximation)
x̂₀ = get_prediction(scheduler, prediction_type, xₜ, ϵᵧ, t)
# compute predicted previous sample μ̃ₜ
# compute μₜ (approximation)
# arxiv:2006.11239 Eq. 7
# arxiv:2208.11970 Eq. 84
λ₀ = ⎷α̅ₜ₋₁ .* βₜ ./ β̅ₜ
λₜ = ⎷αₜ .* β̅ₜ₋₁ ./ β̅ₜ # TODO: this could be stored in the scheduler
λₜ = ⎷αₜ .* β̅ₜ₋₁ ./ β̅ₜ
μ̃ₜ = λ₀ .* x̂₀ + λₜ .* xₜ
# sample predicted previous sample xₜ₋₁
# arxiv:2006.11239 Eq. 6
# arxiv:2208.11970 Eq. 70
σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ # TODO: this could be stored in the scheduler
xₜ₋₁ = μ̃ₜ + σₜ .* randn(size(ϵᵧ))
# compute σ² (exact)
σ²ₜ = get_variance(scheduler, variance_type, xₜ, t)
# sample xₜ₋₁ using μₜ and σ²
σₜ = sqrt.(σ²ₜ)
ϵ = randn(size(ϵᵧ))
xₜ₋₁ = μ̃ₜ + σₜ .* ϵ
return xₜ₋₁, x̂₀
end
"""
    get_velocity(scheduler::DDPM, x₀::AbstractArray, ϵ::AbstractArray, t::AbstractArray)

Compute the velocity `vₜ = ⎷α̅ₜ .* ϵ - ⎷β̅ₜ .* x₀` for the DDPM scheduler.

`⎷α̅ₜ` and `⎷β̅ₜ` are read from the scheduler at indices `t` and reshaped with
`_extract` so that they broadcast along the last dimension of `x₀`.
"""
function get_velocity(
    scheduler::DDPM,
    x₀::AbstractArray,
    ϵ::AbstractArray,
    t::AbstractArray,
)
    # per-timestep coefficients, shaped (1, …, 1, size(x₀, ndims(x₀)))
    noise_scale = _extract(scheduler.⎷α̅[t], x₀)
    data_scale = _extract(scheduler.⎷β̅[t], x₀)

    noise_term = noise_scale .* ϵ
    data_term = data_scale .* x₀

    return noise_term - data_term
end
"""
    get_prediction(scheduler::DDPM, prediction_type::PredictionType, xₜ, ϵᵧ, t)

Estimate the clean sample `x̂₀` from the noisy sample `xₜ` and the model output `ϵᵧ`
at timesteps `t`.

## Input
  * `scheduler::DDPM`: scheduler providing `⎷α̅` and `⎷β̅`
  * `prediction_type::PredictionType`: how `ϵᵧ` is interpreted
      * `EPSILON`: `ϵᵧ` is the noise, `x̂₀ = (xₜ - ⎷β̅ₜ .* ϵᵧ) ./ ⎷α̅ₜ`
      * `SAMPLE`: `ϵᵧ` already is the estimate and is returned as is (no copy)
      * `VELOCITY`: `ϵᵧ` is the velocity, `x̂₀ = ⎷α̅ₜ .* xₜ - ⎷β̅ₜ .* ϵᵧ`
  * `xₜ::AbstractArray`: noisy data
  * `ϵᵧ::AbstractArray`: model output
  * `t::AbstractArray`: timesteps used to index the scheduler variables

## Output
  * `x̂₀::AbstractArray`: estimate of the clean data

## Throws
  * `ArgumentError` if `prediction_type` is none of the values above
"""
function get_prediction(
    scheduler::DDPM,
    prediction_type::PredictionType,
    xₜ::AbstractArray,
    ϵᵧ::AbstractArray,
    t::AbstractArray,
)
    # per-timestep coefficients, reshaped to broadcast along the last dimension of xₜ
    ⎷α̅ₜ = _extract(scheduler.⎷α̅[t], xₜ)
    ⎷β̅ₜ = _extract(scheduler.⎷β̅[t], xₜ)

    if prediction_type == EPSILON
        # arxiv:2006.11239 Eq. 15
        # arxiv:2208.11970 Eq. 115
        x̂₀ = (xₜ - ⎷β̅ₜ .* ϵᵧ) ./ ⎷α̅ₜ
    elseif prediction_type == SAMPLE
        # arxiv:2208.11970 Eq. 99
        x̂₀ = ϵᵧ
    elseif prediction_type == VELOCITY
        # arxiv:2202.00512 Eq. 31
        x̂₀ = ⎷α̅ₜ .* xₜ - ⎷β̅ₜ .* ϵᵧ
    else
        # throw a proper exception (not a bare String) so callers get a typed error
        throw(ArgumentError("unimplemented prediction type: $(prediction_type)"))
    end

    return x̂₀
end
"""
    get_variance(scheduler::DDPM, variance_type::VarianceType, xₜ, t)

Return the reverse-process variance term at timesteps `t`, reshaped with `_extract`
so that it broadcasts along the last dimension of `xₜ`.

## Input
  * `scheduler::DDPM`: scheduler providing `β`, `β̅` and `β̅₋₁`
  * `variance_type::VarianceType`: which variance to return
      * `FIXED_SMALL`: `β̅ₜ₋₁ ./ β̅ₜ .* βₜ`
      * `FIXED_SMALL_LOG`: elementwise `log` of the `FIXED_SMALL` value
      * `FIXED_LARGE`: `βₜ`
      * `FIXED_LARGE_LOG`: elementwise `log` of `βₜ`
  * `xₜ::AbstractArray`: only used as the shape reference for `_extract`
  * `t::AbstractArray`: timesteps used to index the scheduler variables

## Output
  * `σ²ₜ::AbstractArray`: variance (or log-variance for the `*_LOG` variants)

## Throws
  * `ArgumentError` for any other `variance_type` (this includes `LEARNED`)

NOTE(review): the `*_LOG` variants return a logarithm, which is negative for values
below 1 and `-Inf` for a zero variance. A caller that applies `sqrt.` to the result
will raise a `DomainError` on negative entries; confirm the intended contract for
these variants.
"""
function get_variance(
    scheduler::DDPM,
    variance_type::VarianceType,
    xₜ::AbstractArray,
    t::AbstractArray,
)
    βₜ = _extract(scheduler.β[t], xₜ)
    β̅ₜ = _extract(scheduler.β̅[t], xₜ)
    β̅ₜ₋₁ = _extract(scheduler.β̅₋₁[t], xₜ)

    if variance_type == FIXED_SMALL || variance_type == FIXED_SMALL_LOG
        # arxiv:2006.11239 Eq. 6
        # arxiv:2208.11970 Eq. 70
        σ²ₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ
    elseif variance_type == FIXED_LARGE || variance_type == FIXED_LARGE_LOG
        σ²ₜ = βₜ
    else
        # throw a proper exception (not a bare String) so callers get a typed error
        throw(ArgumentError("unimplemented variance type: $(variance_type)"))
    end

    # the *_LOG variants return the logarithm of the value selected above
    if variance_type == FIXED_SMALL_LOG || variance_type == FIXED_LARGE_LOG
        σ²ₜ = log.(σ²ₜ)
    end

    return σ²ₜ
end