mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-09-16 17:45:29 +00:00
♻️ (Schedulers) add variance and prediction types
This commit is contained in:
parent
1e5ceb5140
commit
5ec22c47b2
|
@ -5,9 +5,11 @@ version = "0.1.0"
|
|||
|
||||
[deps]
|
||||
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
|
||||
ShiftedArrays = "1277b4bf-5013-50f5-be3d-901d8477a67a"
|
||||
|
||||
[compat]
|
||||
NNlib = "0.9"
|
||||
ShiftedArrays = "2.0"
|
||||
julia = "1"
|
||||
|
||||
[extras]
|
||||
|
|
|
@ -4,6 +4,9 @@ Abstract type for schedulers.
|
|||
"""
|
||||
abstract type Scheduler end
|
||||
|
||||
# Ways of choosing the variance of the reverse diffusion step.
# The `*_LOG` members select the logarithm of the matching fixed variance.
@enum VarianceType begin
    FIXED_SMALL
    FIXED_SMALL_LOG
    FIXED_LARGE
    FIXED_LARGE_LOG
    LEARNED
end
|
||||
# What the model output passed to the scheduler represents:
# the noise (EPSILON), the clean sample (SAMPLE) or the velocity (VELOCITY).
@enum PredictionType begin
    EPSILON
    SAMPLE
    VELOCITY
end
|
||||
|
||||
"""
|
||||
Add noise to clean data using the forward diffusion process.
|
||||
|
||||
|
@ -42,3 +45,25 @@ function reverse(
|
|||
ϵᵧ::AbstractArray,
|
||||
t::AbstractArray,
|
||||
) end
|
||||
|
||||
"""
|
||||
Compute the velocity of the diffusion process.
|
||||
|
||||
## Input
|
||||
* `scheduler::Scheduler`: scheduler to use
|
||||
* `x₀::AbstractArray`: clean data to add noise to
|
||||
* `ϵ::AbstractArray`: noise to add to clean data
|
||||
* `t::AbstractArray`: timesteps used to weight the noise
|
||||
|
||||
## Output
|
||||
* `vₜ::AbstractArray`: velocity at the given timesteps
|
||||
|
||||
## References
|
||||
* [[2202.00512] Progressive Distillation for Fast Sampling of Diffusion Models](https://arxiv.org/abs/2202.00512) (App. D)
|
||||
"""
|
||||
function get_velocity(
    scheduler::Scheduler,
    x₀::AbstractArray,
    ϵ::AbstractArray,
    t::AbstractArray,
)
    # Generic fallback: performs no computation and yields `nothing`.
    # Concrete schedulers add their own method for their scheduler type.
    return nothing
end
|
||||
|
|
|
@ -1,5 +1,18 @@
|
|||
include("Abstract.jl")
|
||||
|
||||
using ShiftedArrays
|
||||
|
||||
# Reshape `target` so that it broadcasts against `reference` along the last axis.
#
# ## Input
# * `target::AbstractArray`: values to reshape; it must hold exactly
#   `size(reference, ndims(reference))` elements, otherwise `reshape` throws
# * `reference::AbstractArray`: array whose rank and last-axis size define the shape
#
# ## Output
# * `target` reshaped to `(1, …, 1, size(reference, ndims(reference)))`, i.e. with the
#   same number of dimensions as `reference`
function _extract(
    target::AbstractArray,
    reference::AbstractArray{<:Any,N},
) where {N}
    N >= 1 || throw(ArgumentError("`reference` must have at least one dimension"))

    last_len = size(reference, N)

    # Build the shape as a tuple whose length is known from the type parameter `N`,
    # instead of splatting a runtime-length vector (which is not inferrable).
    new_size = ntuple(d -> d == N ? last_len : 1, Val(N))

    return reshape(target, new_size)
end
|
||||
|
||||
"""
|
||||
Denoising Diffusion Probabilistic Models (DDPM) scheduler.
|
||||
|
||||
|
@ -9,8 +22,8 @@ Denoising Diffusion Probabilistic Models (DDPM) scheduler.
|
|||
struct DDPM{V<:AbstractVector} <: Scheduler
|
||||
T::Integer # length of markov chain
|
||||
|
||||
β::V # beta variance schedule
|
||||
α::V # 1 - beta
|
||||
β::V # beta variance schedule
|
||||
|
||||
⎷α::V # square root of α
|
||||
⎷β::V # square root of β
|
||||
|
@ -39,7 +52,7 @@ function DDPM(β::AbstractVector)
|
|||
α̅ = cumprod(α)
|
||||
β̅ = 1 .- α̅
|
||||
|
||||
α̅₋₁ = [1, (α̅[1:end-1])...]
|
||||
α̅₋₁ = ShiftedArray(α̅, 1, default=1)
|
||||
β̅₋₁ = 1 .- α̅₋₁
|
||||
|
||||
⎷α̅ = sqrt.(α̅)
|
||||
|
@ -50,18 +63,12 @@ function DDPM(β::AbstractVector)
|
|||
|
||||
DDPM{typeof(β)}(
|
||||
T,
|
||||
β,
|
||||
α,
|
||||
⎷α,
|
||||
⎷β,
|
||||
α̅,
|
||||
β̅,
|
||||
α̅₋₁,
|
||||
β̅₋₁,
|
||||
⎷α̅,
|
||||
⎷β̅,
|
||||
⎷α̅₋₁,
|
||||
⎷β̅₋₁,
|
||||
α, β,
|
||||
⎷α, ⎷β,
|
||||
α̅, β̅,
|
||||
α̅₋₁, β̅₋₁,
|
||||
⎷α̅, ⎷β̅,
|
||||
⎷α̅₋₁, ⎷β̅₋₁,
|
||||
)
|
||||
end
|
||||
|
||||
|
@ -72,12 +79,8 @@ function forward(
|
|||
t::AbstractArray,
|
||||
)
|
||||
# retrieve scheduler variables at timesteps t
|
||||
reshape_size = tuple(
|
||||
fill(1, ndims(x₀) - 1)...,
|
||||
size(t, 1)
|
||||
)
|
||||
⎷α̅ₜ = reshape(scheduler.⎷α̅[t], reshape_size)
|
||||
⎷β̅ₜ = reshape(scheduler.⎷β̅[t], reshape_size)
|
||||
⎷α̅ₜ = _extract(scheduler.⎷α̅[t], x₀)
|
||||
⎷β̅ₜ = _extract(scheduler.⎷β̅[t], x₀)
|
||||
|
||||
# noisify clean data
|
||||
# arxiv:2006.11239 Eq. 4
|
||||
|
@ -91,37 +94,103 @@ function reverse(
|
|||
xₜ::AbstractArray,
|
||||
ϵᵧ::AbstractArray,
|
||||
t::AbstractArray,
|
||||
prediction_type::PredictionType=EPSILON,
|
||||
variance_type::VarianceType=FIXED_SMALL,
|
||||
)
|
||||
# retrieve scheduler variables at timesteps t
|
||||
reshape_size = tuple(
|
||||
fill(1, ndims(xₜ) - 1)...,
|
||||
size(t, 1)
|
||||
)
|
||||
βₜ = reshape(scheduler.β[t], reshape_size)
|
||||
β̅ₜ = reshape(scheduler.β̅[t], reshape_size)
|
||||
β̅ₜ₋₁ = reshape(scheduler.β̅₋₁[t], reshape_size)
|
||||
⎷αₜ = reshape(scheduler.⎷α[t], reshape_size)
|
||||
⎷α̅ₜ = reshape(scheduler.⎷α̅[t], reshape_size)
|
||||
⎷α̅ₜ₋₁ = reshape(scheduler.⎷α̅₋₁[t], reshape_size)
|
||||
⎷β̅ₜ = reshape(scheduler.⎷β̅[t], reshape_size)
|
||||
βₜ = _extract(scheduler.β[t], xₜ)
|
||||
β̅ₜ = _extract(scheduler.β̅[t], xₜ)
|
||||
β̅ₜ₋₁ = _extract(scheduler.β̅₋₁[t], xₜ)
|
||||
⎷αₜ = _extract(scheduler.⎷α[t], xₜ)
|
||||
⎷α̅ₜ₋₁ = _extract(scheduler.⎷α̅₋₁[t], xₜ)
|
||||
|
||||
# compute predicted previous sample x̂₀
|
||||
# arxiv:2006.11239 Eq. 15
|
||||
# arxiv:2208.11970 Eq. 115
|
||||
x̂₀ = (xₜ - ⎷β̅ₜ .* ϵᵧ) ./ ⎷α̅ₜ
|
||||
# compute x₀ (approximation)
|
||||
x̂₀ = get_prediction(scheduler, prediction_type, xₜ, ϵᵧ, t)
|
||||
|
||||
# compute predicted previous sample μ̃ₜ
|
||||
# compute μₜ (approximation)
|
||||
# arxiv:2006.11239 Eq. 7
|
||||
# arxiv:2208.11970 Eq. 84
|
||||
λ₀ = ⎷α̅ₜ₋₁ .* βₜ ./ β̅ₜ
|
||||
λₜ = ⎷αₜ .* β̅ₜ₋₁ ./ β̅ₜ # TODO: this could be stored in the scheduler
|
||||
λₜ = ⎷αₜ .* β̅ₜ₋₁ ./ β̅ₜ
|
||||
μ̃ₜ = λ₀ .* x̂₀ + λₜ .* xₜ
|
||||
|
||||
# sample predicted previous sample xₜ₋₁
|
||||
# arxiv:2006.11239 Eq. 6
|
||||
# arxiv:2208.11970 Eq. 70
|
||||
σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ # TODO: this could be stored in the scheduler
|
||||
xₜ₋₁ = μ̃ₜ + σₜ .* randn(size(ϵᵧ))
|
||||
# compute σ² (exact)
|
||||
σ²ₜ = get_variance(scheduler, variance_type, xₜ, t)
|
||||
|
||||
# sample xₜ₋₁ using μₜ and σ²
|
||||
σₜ = sqrt.(σ²ₜ)
|
||||
ϵ = randn(size(ϵᵧ))
|
||||
xₜ₋₁ = μ̃ₜ + σₜ .* ϵ
|
||||
|
||||
return xₜ₋₁, x̂₀
|
||||
end
|
||||
|
||||
# DDPM method of `get_velocity`: combines the noise `ϵ` and the clean data `x₀`
# using the scheduler's ⎷α̅ and ⎷β̅ tables looked up at timesteps `t`.
function get_velocity(
    scheduler::DDPM,
    x₀::AbstractArray,
    ϵ::AbstractArray,
    t::AbstractArray,
)
    # per-timestep coefficients, reshaped to broadcast along the last axis of x₀
    noise_scale = _extract(scheduler.⎷α̅[t], x₀)
    data_scale = _extract(scheduler.⎷β̅[t], x₀)

    # velocity = ⎷α̅ₜ ⋅ ϵ − ⎷β̅ₜ ⋅ x₀
    scaled_noise = noise_scale .* ϵ
    scaled_data = data_scale .* x₀

    return scaled_noise - scaled_data
end
|
||||
|
||||
"""
    get_prediction(scheduler::DDPM, prediction_type::PredictionType, xₜ, ϵᵧ, t)

Estimate the clean sample x̂₀ from the noisy sample and the model output.

## Input
* `scheduler::DDPM`: scheduler providing the `⎷α̅` and `⎷β̅` tables
* `prediction_type::PredictionType`: how `ϵᵧ` must be interpreted
  (`EPSILON`: noise, `SAMPLE`: clean sample, `VELOCITY`: velocity)
* `xₜ::AbstractArray`: noisy data at timesteps `t`
* `ϵᵧ::AbstractArray`: model output
* `t::AbstractArray`: timesteps used to index the scheduler tables

## Output
* `x̂₀::AbstractArray`: estimate of the clean data; for `SAMPLE` this is `ϵᵧ` itself
  (no copy is made)

## Throws
* `ArgumentError`: if `prediction_type` is not one of the handled values
"""
function get_prediction(
    scheduler::DDPM,
    prediction_type::PredictionType,
    xₜ::AbstractArray,
    ϵᵧ::AbstractArray,
    t::AbstractArray,
)
    # per-timestep coefficients, reshaped to broadcast along the last axis of xₜ
    ⎷α̅ₜ = _extract(scheduler.⎷α̅[t], xₜ)
    ⎷β̅ₜ = _extract(scheduler.⎷β̅[t], xₜ)

    if prediction_type == EPSILON
        # arxiv:2006.11239 Eq. 15
        # arxiv:2208.11970 Eq. 115
        x̂₀ = (xₜ - ⎷β̅ₜ .* ϵᵧ) ./ ⎷α̅ₜ
    elseif prediction_type == SAMPLE
        # arxiv:2208.11970 Eq. 99
        x̂₀ = ϵᵧ
    elseif prediction_type == VELOCITY
        # arxiv:2202.00512 Eq. 31
        x̂₀ = ⎷α̅ₜ .* xₜ - ⎷β̅ₜ .* ϵᵧ
    else
        # throw a proper exception type rather than a bare string
        throw(ArgumentError("unimplemented prediction type: $(prediction_type)"))
    end

    return x̂₀
end
|
||||
|
||||
"""
    get_variance(scheduler::DDPM, variance_type::VarianceType, xₜ, t)

Compute the variance of the reverse diffusion step at timesteps `t`.

## Input
* `scheduler::DDPM`: scheduler providing the `β`, `β̅` and `β̅₋₁` tables
* `variance_type::VarianceType`: which variance to return
* `xₜ::AbstractArray`: noisy data; only used to shape the result so that it broadcasts
  along the last axis of `xₜ`
* `t::AbstractArray`: timesteps used to index the scheduler tables

## Output
* `σ²ₜ::AbstractArray`: `β̅ₜ₋₁ / β̅ₜ ⋅ βₜ` for `FIXED_SMALL`, `βₜ` for `FIXED_LARGE`,
  and the elementwise logarithm of those for the `*_LOG` variants

## Throws
* `ArgumentError`: if `variance_type` is not implemented (e.g. `LEARNED`)
"""
function get_variance(
    scheduler::DDPM,
    variance_type::VarianceType,
    xₜ::AbstractArray,
    t::AbstractArray,
)
    # per-timestep coefficients, reshaped to broadcast along the last axis of xₜ
    βₜ = _extract(scheduler.β[t], xₜ)
    β̅ₜ = _extract(scheduler.β̅[t], xₜ)
    β̅ₜ₋₁ = _extract(scheduler.β̅₋₁[t], xₜ)

    if variance_type in (FIXED_SMALL, FIXED_SMALL_LOG)
        # arxiv:2006.11239 Eq. 6
        # arxiv:2208.11970 Eq. 70
        σ²ₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ
    elseif variance_type in (FIXED_LARGE, FIXED_LARGE_LOG)
        σ²ₜ = βₜ
    else
        # throw a proper exception type rather than a bare string
        throw(ArgumentError("unimplemented variance type: $(variance_type)"))
    end

    if variance_type in (FIXED_SMALL_LOG, FIXED_LARGE_LOG)
        # NOTE(review): `reverse` applies `sqrt.` to this result, which is only
        # meaningful for the non-log variants — confirm how the log values are meant
        # to be consumed. Wherever β̅ₜ₋₁ is 0 the FIXED_SMALL_LOG value is -Inf.
        σ²ₜ = log.(σ²ₜ)
    end

    return σ²ₜ
end
|
||||
|
|
Loading…
Reference in a new issue