mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-09-18 18:45:28 +00:00
♻️ (Schedulers) add variance and prediction types
This commit is contained in:
parent
1e5ceb5140
commit
5ec22c47b2
|
@ -5,9 +5,11 @@ version = "0.1.0"
|
||||||
|
|
||||||
[deps]
|
[deps]
|
||||||
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
|
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
|
||||||
|
ShiftedArrays = "1277b4bf-5013-50f5-be3d-901d8477a67a"
|
||||||
|
|
||||||
[compat]
|
[compat]
|
||||||
NNlib = "0.9"
|
NNlib = "0.9"
|
||||||
|
ShiftedArrays = "2.0"
|
||||||
julia = "1"
|
julia = "1"
|
||||||
|
|
||||||
[extras]
|
[extras]
|
||||||
|
|
|
@ -4,6 +4,9 @@ Abstract type for schedulers.
|
||||||
"""
|
"""
|
||||||
abstract type Scheduler end
|
abstract type Scheduler end
|
||||||
|
|
||||||
|
# Closed set of options selecting which variance expression a scheduler uses
# (consumed by `get_variance`).
@enum VarianceType begin
    FIXED_SMALL
    FIXED_SMALL_LOG
    FIXED_LARGE
    FIXED_LARGE_LOG
    LEARNED
end

# Closed set of options describing how the model output passed to a scheduler
# is interpreted (consumed by `get_prediction`).
@enum PredictionType begin
    EPSILON
    SAMPLE
    VELOCITY
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Add noise to clean data using the forward diffusion process.
|
Add noise to clean data using the forward diffusion process.
|
||||||
|
|
||||||
|
@ -42,3 +45,25 @@ function reverse(
|
||||||
ϵᵧ::AbstractArray,
|
ϵᵧ::AbstractArray,
|
||||||
t::AbstractArray,
|
t::AbstractArray,
|
||||||
) end
|
) end
|
||||||
|
|
||||||
|
"""
|
||||||
|
Compute the velocity of the diffusion process.
|
||||||
|
|
||||||
|
## Input
|
||||||
|
* `scheduler::Scheduler`: scheduler to use
|
||||||
|
* `x₀::AbstractArray`: clean data to add noise to
|
||||||
|
* `ϵ::AbstractArray`: noise to add to clean data
|
||||||
|
* `t::AbstractArray`: timesteps used to weight the noise
|
||||||
|
|
||||||
|
## Output
|
||||||
|
* `vₜ::AbstractArray`: velocity at the given timesteps
|
||||||
|
|
||||||
|
## References
|
||||||
|
* [[2202.00512] Progressive Distillation for Fast Sampling of Diffusion Models](https://arxiv.org/abs/2202.00512) (Appendix D)
|
||||||
|
"""
|
||||||
|
function get_velocity(
    scheduler::Scheduler, x₀::AbstractArray, ϵ::AbstractArray, t::AbstractArray
)
    # Generic fallback for any `Scheduler`: the body is empty, so the call
    # evaluates to `nothing`. Concrete schedulers add their own methods.
    return nothing
end
|
||||||
|
|
|
@ -1,5 +1,18 @@
|
||||||
include("Abstract.jl")
|
include("Abstract.jl")
|
||||||
|
|
||||||
|
using ShiftedArrays
|
||||||
|
|
||||||
|
"""
    _extract(target::AbstractArray, reference::AbstractArray)

Reshape `target` so that it broadcasts against `reference` along the last dimension.

## Input
  * `target::AbstractArray`: values to reshape; must contain exactly
    `size(reference, ndims(reference))` elements, otherwise `reshape` throws
  * `reference::AbstractArray`: array whose rank and last-dimension size define the new shape

## Output
  * `target` reshaped to `(1, …, 1, size(reference, N))`, with `N = ndims(reference)`
"""
function _extract(
    target::AbstractArray,
    reference::AbstractArray{<:Any,N},
) where {N}
    # The rank is taken from the type parameter so the singleton dimensions
    # form a tuple of statically known length (the result stays inferrable).
    singleton_dims = ntuple(_ -> 1, Val(N - 1))
    return reshape(target, singleton_dims..., size(reference, N))
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Denoising Diffusion Probabilistic Models (DDPM) scheduler.
|
Denoising Diffusion Probabilistic Models (DDPM) scheduler.
|
||||||
|
|
||||||
|
@ -9,8 +22,8 @@ Denoising Diffusion Probabilistic Models (DDPM) scheduler.
|
||||||
struct DDPM{V<:AbstractVector} <: Scheduler
|
struct DDPM{V<:AbstractVector} <: Scheduler
|
||||||
T::Integer # length of markov chain
|
T::Integer # length of markov chain
|
||||||
|
|
||||||
β::V # beta variance schedule
|
|
||||||
α::V # 1 - beta
|
α::V # 1 - beta
|
||||||
|
β::V # beta variance schedule
|
||||||
|
|
||||||
⎷α::V # square root of α
|
⎷α::V # square root of α
|
||||||
⎷β::V # square root of β
|
⎷β::V # square root of β
|
||||||
|
@ -39,7 +52,7 @@ function DDPM(β::AbstractVector)
|
||||||
α̅ = cumprod(α)
|
α̅ = cumprod(α)
|
||||||
β̅ = 1 .- α̅
|
β̅ = 1 .- α̅
|
||||||
|
|
||||||
α̅₋₁ = [1, (α̅[1:end-1])...]
|
α̅₋₁ = ShiftedArray(α̅, 1, default=1)
|
||||||
β̅₋₁ = 1 .- α̅₋₁
|
β̅₋₁ = 1 .- α̅₋₁
|
||||||
|
|
||||||
⎷α̅ = sqrt.(α̅)
|
⎷α̅ = sqrt.(α̅)
|
||||||
|
@ -50,18 +63,12 @@ function DDPM(β::AbstractVector)
|
||||||
|
|
||||||
DDPM{typeof(β)}(
|
DDPM{typeof(β)}(
|
||||||
T,
|
T,
|
||||||
β,
|
α, β,
|
||||||
α,
|
⎷α, ⎷β,
|
||||||
⎷α,
|
α̅, β̅,
|
||||||
⎷β,
|
α̅₋₁, β̅₋₁,
|
||||||
α̅,
|
⎷α̅, ⎷β̅,
|
||||||
β̅,
|
⎷α̅₋₁, ⎷β̅₋₁,
|
||||||
α̅₋₁,
|
|
||||||
β̅₋₁,
|
|
||||||
⎷α̅,
|
|
||||||
⎷β̅,
|
|
||||||
⎷α̅₋₁,
|
|
||||||
⎷β̅₋₁,
|
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -72,12 +79,8 @@ function forward(
|
||||||
t::AbstractArray,
|
t::AbstractArray,
|
||||||
)
|
)
|
||||||
# retrieve scheduler variables at timesteps t
|
# retrieve scheduler variables at timesteps t
|
||||||
reshape_size = tuple(
|
⎷α̅ₜ = _extract(scheduler.⎷α̅[t], x₀)
|
||||||
fill(1, ndims(x₀) - 1)...,
|
⎷β̅ₜ = _extract(scheduler.⎷β̅[t], x₀)
|
||||||
size(t, 1)
|
|
||||||
)
|
|
||||||
⎷α̅ₜ = reshape(scheduler.⎷α̅[t], reshape_size)
|
|
||||||
⎷β̅ₜ = reshape(scheduler.⎷β̅[t], reshape_size)
|
|
||||||
|
|
||||||
# noisify clean data
|
# noisify clean data
|
||||||
# arxiv:2006.11239 Eq. 4
|
# arxiv:2006.11239 Eq. 4
|
||||||
|
@ -91,37 +94,103 @@ function reverse(
|
||||||
xₜ::AbstractArray,
|
xₜ::AbstractArray,
|
||||||
ϵᵧ::AbstractArray,
|
ϵᵧ::AbstractArray,
|
||||||
t::AbstractArray,
|
t::AbstractArray,
|
||||||
|
prediction_type::PredictionType=EPSILON,
|
||||||
|
variance_type::VarianceType=FIXED_SMALL,
|
||||||
)
|
)
|
||||||
# retrieve scheduler variables at timesteps t
|
# retrieve scheduler variables at timesteps t
|
||||||
reshape_size = tuple(
|
βₜ = _extract(scheduler.β[t], xₜ)
|
||||||
fill(1, ndims(xₜ) - 1)...,
|
β̅ₜ = _extract(scheduler.β̅[t], xₜ)
|
||||||
size(t, 1)
|
β̅ₜ₋₁ = _extract(scheduler.β̅₋₁[t], xₜ)
|
||||||
)
|
⎷αₜ = _extract(scheduler.⎷α[t], xₜ)
|
||||||
βₜ = reshape(scheduler.β[t], reshape_size)
|
⎷α̅ₜ₋₁ = _extract(scheduler.⎷α̅₋₁[t], xₜ)
|
||||||
β̅ₜ = reshape(scheduler.β̅[t], reshape_size)
|
|
||||||
β̅ₜ₋₁ = reshape(scheduler.β̅₋₁[t], reshape_size)
|
|
||||||
⎷αₜ = reshape(scheduler.⎷α[t], reshape_size)
|
|
||||||
⎷α̅ₜ = reshape(scheduler.⎷α̅[t], reshape_size)
|
|
||||||
⎷α̅ₜ₋₁ = reshape(scheduler.⎷α̅₋₁[t], reshape_size)
|
|
||||||
⎷β̅ₜ = reshape(scheduler.⎷β̅[t], reshape_size)
|
|
||||||
|
|
||||||
# compute predicted previous sample x̂₀
|
# compute x₀ (approximation)
|
||||||
# arxiv:2006.11239 Eq. 15
|
x̂₀ = get_prediction(scheduler, prediction_type, xₜ, ϵᵧ, t)
|
||||||
# arxiv:2208.11970 Eq. 115
|
|
||||||
x̂₀ = (xₜ - ⎷β̅ₜ .* ϵᵧ) ./ ⎷α̅ₜ
|
|
||||||
|
|
||||||
# compute predicted previous sample μ̃ₜ
|
# compute μₜ (approximation)
|
||||||
# arxiv:2006.11239 Eq. 7
|
# arxiv:2006.11239 Eq. 7
|
||||||
# arxiv:2208.11970 Eq. 84
|
# arxiv:2208.11970 Eq. 84
|
||||||
λ₀ = ⎷α̅ₜ₋₁ .* βₜ ./ β̅ₜ
|
λ₀ = ⎷α̅ₜ₋₁ .* βₜ ./ β̅ₜ
|
||||||
λₜ = ⎷αₜ .* β̅ₜ₋₁ ./ β̅ₜ # TODO: this could be stored in the scheduler
|
λₜ = ⎷αₜ .* β̅ₜ₋₁ ./ β̅ₜ
|
||||||
μ̃ₜ = λ₀ .* x̂₀ + λₜ .* xₜ
|
μ̃ₜ = λ₀ .* x̂₀ + λₜ .* xₜ
|
||||||
|
|
||||||
# sample predicted previous sample xₜ₋₁
|
# compute σ² (exact)
|
||||||
# arxiv:2006.11239 Eq. 6
|
σ²ₜ = get_variance(scheduler, variance_type, xₜ, t)
|
||||||
# arxiv:2208.11970 Eq. 70
|
|
||||||
σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ # TODO: this could be stored in the scheduler
|
# sample xₜ₋₁ using μₜ and σ²
|
||||||
xₜ₋₁ = μ̃ₜ + σₜ .* randn(size(ϵᵧ))
|
σₜ = sqrt.(σ²ₜ)
|
||||||
|
ϵ = randn(size(ϵᵧ))
|
||||||
|
xₜ₋₁ = μ̃ₜ + σₜ .* ϵ
|
||||||
|
|
||||||
return xₜ₋₁, x̂₀
|
return xₜ₋₁, x̂₀
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
    get_velocity(scheduler::DDPM, x₀, ϵ, t)

Compute the velocity `vₜ = ⎷α̅ₜ .* ϵ - ⎷β̅ₜ .* x₀` for the DDPM scheduler.

## Input
  * `scheduler::DDPM`: scheduler providing the `⎷α̅` and `⎷β̅` tables
  * `x₀::AbstractArray`: clean data
  * `ϵ::AbstractArray`: noise
  * `t::AbstractArray`: timesteps used to index the scheduler tables

## Output
  * `vₜ::AbstractArray`: velocity at the given timesteps
"""
function get_velocity(
    scheduler::DDPM,
    x₀::AbstractArray,
    ϵ::AbstractArray,
    t::AbstractArray,
)
    # per-timestep coefficients, reshaped to broadcast along the last dimension of x₀
    noise_coeff = _extract(scheduler.⎷α̅[t], x₀)
    data_coeff = _extract(scheduler.⎷β̅[t], x₀)

    return noise_coeff .* ϵ - data_coeff .* x₀
end
|
||||||
|
|
||||||
|
"""
    get_prediction(scheduler::DDPM, prediction_type::PredictionType, xₜ, ϵᵧ, t)

Compute the estimate `x̂₀` of the clean data from the noisy sample `xₜ` and the
model output `ϵᵧ`, interpreting `ϵᵧ` according to `prediction_type`.

## Input
  * `scheduler::DDPM`: scheduler providing the `⎷α̅` and `⎷β̅` tables
  * `prediction_type::PredictionType`: how `ϵᵧ` is interpreted
    (`EPSILON`, `SAMPLE` or `VELOCITY`)
  * `xₜ::AbstractArray`: noisy sample at timesteps `t`
  * `ϵᵧ::AbstractArray`: model output
  * `t::AbstractArray`: timesteps used to index the scheduler tables

## Output
  * `x̂₀::AbstractArray`: estimate of the clean data; for `SAMPLE` this is `ϵᵧ`
    itself (not a copy)

## Throws
  * `ArgumentError`: if `prediction_type` is not handled
"""
function get_prediction(
    scheduler::DDPM,
    prediction_type::PredictionType,
    xₜ::AbstractArray,
    ϵᵧ::AbstractArray,
    t::AbstractArray,
)
    # per-timestep coefficients, reshaped to broadcast along the last dimension of xₜ
    ⎷α̅ₜ = _extract(scheduler.⎷α̅[t], xₜ)
    ⎷β̅ₜ = _extract(scheduler.⎷β̅[t], xₜ)

    if prediction_type == EPSILON
        # arxiv:2006.11239 Eq. 15
        # arxiv:2208.11970 Eq. 115
        return (xₜ - ⎷β̅ₜ .* ϵᵧ) ./ ⎷α̅ₜ
    elseif prediction_type == SAMPLE
        # arxiv:2208.11970 Eq. 99
        return ϵᵧ
    elseif prediction_type == VELOCITY
        # arxiv:2202.00512 Eq. 31
        return ⎷α̅ₜ .* xₜ - ⎷β̅ₜ .* ϵᵧ
    end

    # Defensive guard for enum values added later without a branch above.
    throw(ArgumentError("unimplemented prediction type"))
end
|
||||||
|
|
||||||
|
"""
    get_variance(scheduler::DDPM, variance_type::VarianceType, xₜ, t)

Compute the variance term selected by `variance_type` at timesteps `t`.

## Input
  * `scheduler::DDPM`: scheduler providing the `β`, `β̅` and `β̅₋₁` tables
  * `variance_type::VarianceType`: which expression to return
  * `xₜ::AbstractArray`: sample whose rank and last dimension define the output shape
  * `t::AbstractArray`: timesteps used to index the scheduler tables

## Output
  * `FIXED_SMALL`: `β̅ₜ₋₁ ./ β̅ₜ .* βₜ`
  * `FIXED_SMALL_LOG`: elementwise `log` of the `FIXED_SMALL` value
  * `FIXED_LARGE`: `βₜ`
  * `FIXED_LARGE_LOG`: elementwise `log` of `βₜ`

## Throws
  * `ArgumentError`: for any other variance type (currently `LEARNED`)
"""
function get_variance(
    scheduler::DDPM,
    variance_type::VarianceType,
    xₜ::AbstractArray,
    t::AbstractArray,
)
    # per-timestep coefficients, reshaped to broadcast along the last dimension of xₜ
    βₜ = _extract(scheduler.β[t], xₜ)
    β̅ₜ = _extract(scheduler.β̅[t], xₜ)
    β̅ₜ₋₁ = _extract(scheduler.β̅₋₁[t], xₜ)

    # NOTE(review): the `_LOG` variants return log-values, while the caller
    # `reverse` applies `sqrt.` to whatever this function returns; confirm the
    # intended handling of the log variants at the call site.
    if variance_type == FIXED_SMALL || variance_type == FIXED_SMALL_LOG
        # arxiv:2006.11239 Eq. 6
        # arxiv:2208.11970 Eq. 70
        σ²ₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ
        return variance_type == FIXED_SMALL ? σ²ₜ : log.(σ²ₜ)
    elseif variance_type == FIXED_LARGE
        return βₜ
    elseif variance_type == FIXED_LARGE_LOG
        return log.(βₜ)
    end

    throw(ArgumentError("unimplemented variance type"))
end
|
||||||
|
|
Loading…
Reference in a new issue