♻️ rework docstrings and variable names

This commit is contained in:
Laureηt 2023-07-29 15:18:42 +02:00
parent 43695ffcf7
commit f6ecfdbfda
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
7 changed files with 162 additions and 128 deletions

View file

@ -46,7 +46,7 @@ function Base.show(io::IO, c::ConditionalChain)
print(io, ")")
end
function _big_show(io::IO, m::ConditionalChain{T}, indent::Int=0, name=nothing) where {T<:NamedTuple}
function _big_show(io::IO, m::ConditionalChain{T}, indent::Integer=0, name=nothing) where {T<:NamedTuple}
println(io, " "^indent, isnothing(name) ? "" : "$name = ", "ConditionalChain(")
for k in Base.keys(m.layers)
_big_show(io, m.layers[k], indent + 2, k)

View file

@ -7,12 +7,12 @@ end
Flux.@functor SinusoidalPositionEmbedding
Flux.trainable(emb::SinusoidalPositionEmbedding) = () # mark it as an non-trainable array
function SinusoidalPositionEmbedding(in::Int, out::Int)
function SinusoidalPositionEmbedding(in::Integer, out::Integer)
W = make_positional_embedding(out, in)
SinusoidalPositionEmbedding(W)
end
function make_positional_embedding(dim_embedding::Int, seq_length::Int=1000; n::Int=10000)
function make_positional_embedding(dim_embedding::Integer, seq_length::Integer=1000; n::Integer=10000)
embedding = Matrix{Float32}(undef, dim_embedding, seq_length)
for pos in 1:seq_length
for row in 0:2:(dim_embedding-1)

View file

@ -8,7 +8,7 @@ using ProgressMeter
include("Embeddings.jl")
include("ConditionalChain.jl")
function make_spiral(rng::AbstractRNG, n_samples::Int=1000)
function make_spiral(rng::AbstractRNG, n_samples::Integer=1000)
t_min = 1.5π
t_max = 4.5π
@ -20,7 +20,7 @@ function make_spiral(rng::AbstractRNG, n_samples::Int=1000)
permutedims([x y], (2, 1))
end
make_spiral(n_samples::Int=1000) = make_spiral(Random.GLOBAL_RNG, n_samples)
make_spiral(n_samples::Integer=1000) = make_spiral(Random.GLOBAL_RNG, n_samples)
function normalize_zero_to_one(x)
x_min, x_max = extrema(x)
@ -42,7 +42,9 @@ scatter(data[1, :], data[2, :],
num_timesteps = 100
scheduler = Diffusers.DDPM(
Vector{Float64},
Diffusers.cosine_beta_schedule(num_timesteps, 0.999f0, 0.001f0),
Diffusers.rescale_zero_terminal_snr(
Diffusers.cosine_beta_schedule(num_timesteps, 0.999f0, 0.001f0)
)
)
noise = randn(size(data))
@ -104,7 +106,7 @@ model = ConditionalChain(
model(data, [100])
num_epochs = 1000
num_epochs = 100
loss = Flux.Losses.mse
opt = Flux.setup(Adam(0.0001), model)
dataloader = Flux.DataLoader(data |> cpu; batchsize=32, shuffle=true);

View file

@ -3,117 +3,122 @@ import NNlib: sigmoid
"""
Linear beta schedule.
cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
## Input
* T (`Int`): number of timesteps
* β_1 (`Real := 0.0001f0`): initial value of β
* β_T (`Real := 0.02f0`): final value of β
* `T::Integer`: number of timesteps
* `β₁::Real=0.0001f0`: initial (t=1) value of β
* `β₋₁::Real=0.02f0`: final (t=T) value of β
## Output
* β (`Vector{Real}`): β_t values at each timestep t
* `β::Vector{Real}`: βₜ values at each timestep t
## References
* [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
"""
function linear_beta_schedule(T::Int, β_1::Real=0.0001f0, β_T::Real=0.02f0)
return range(start=β_1, stop=β_T, length=T)
function linear_beta_schedule(T::Integer, β₁::Real=0.0001f0, β₋₁::Real=0.02f0)
return range(start=β, stop=β₋₁, length=T)
end
"""
Scaled linear beta schedule.
cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
## Input
* T (`Int`): number of timesteps
* β_1 (`Real := 0.0001f0`): initial value of β
* β_T (`Real := 0.02f0`): final value of β
* `T::Int`: number of timesteps
* `β₁::Real=0.0001f0`: initial value of β
* `β₋₁::Real=0.02f0`: final value of β
## Output
* β (`Vector{Real}`): β_t values at each timestep t
* `β::Vector{Real}`: βₜ values at each timestep t
## References
* [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
"""
function scaled_linear_beta_schedule(T::Int, β_1::Real=0.0001f0, β_T::Real=0.02f0)
return range(start=β_1^0.5, stop=β_T^0.5, length=T) .^ 2
function scaled_linear_beta_schedule(T::Integer, β₁::Real=0.0001f0, β₋₁::Real=0.02f0)
return range(start=β^0.5, stop=β₋₁^0.5, length=T) .^ 2
end
"""
Sigmoid beta schedule.
cf. [[2203.02923] GeoDiff: a Geometric Diffusion Model for Molecular Conformation Generation](https://arxiv.org/abs/2203.02923)
and [github.com:MinkaiXu/GeoDiff](https://github.com/MinkaiXu/GeoDiff/blob/ea0ca48045a2f7abfccd7f0df449e45eb6eae638/models/epsnet/diffusion.py#L57)
## Input
* T (`Int`): number of timesteps
* β_1 (`Real := 0.0001f0`): initial value of β
* β_T (`Real := 0.02f0`): final value of β
* `T::Int`: number of timesteps
* `β₁::Real=0.0001f0`: initial value of β
* `β₋₁::Real=0.02f0`: final value of β
## Output
* β (`Vector{Real}`): β_t values at each timestep t
* `β::Vector{Real}`: βₜ values at each timestep t
## References
* [[2203.02923] GeoDiff: a Geometric Diffusion Model for Molecular Conformation Generation](https://arxiv.org/abs/2203.02923)
* [github.com:MinkaiXu/GeoDiff](https://github.com/MinkaiXu/GeoDiff/blob/ea0ca48045a2f7abfccd7f0df449e45eb6eae638/models/epsnet/diffusion.py#L57)
"""
function sigmoid_beta_schedule(T::Int, β_1::Real=0.0001f0, β_T::Real=0.02f0)
function sigmoid_beta_schedule(T::Integer, β₁::Real=0.0001f0, β₋₁::Real=0.02f0)
x = range(start=-6, stop=6, length=T)
return sigmoid(x) .* (β_T - β_1) .+ β_1
return sigmoid(x) .* (β₋₁ - β₁) .+ β₁
end
"""
Cosine beta schedule.
cf. [[2102.09672] Improved Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2102.09672)
## Input
* T (`Int`): number of timesteps
* β_max (`Real := 0.999f0`): maximum value of β
* ϵ (`Real := 1e-3f0`): small value used to avoid division by zero
* `T::Int`: number of timesteps
* `βₘₐₓ::Real=0.999f0`: maximum value of β
* `ϵ::Real=1e-3f0`: small value used to avoid division by zero
## Output
* β (`Vector{Real}`): β_t values at each timestep t
* `β::Vector{Real}`: βₜ values at each timestep t
## References
* [[2102.09672] Improved Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2102.09672)
* [github:openai/improved-diffusion](https://github.com/openai/improved-diffusion/blob/783b6740edb79fdb7d063250db2c51cc9545dcd1/improved_diffusion/gaussian_diffusion.py#L36)
"""
function cosine_beta_schedule(T::Int, β_max::Real=0.999f0, ϵ::Real=0.001f0)
α_bar(t) = cos((t + ϵ) / (1 + ϵ) * π / 2)^2
function cosine_beta_schedule(T::Integer, βₘₐₓ::Real=0.999f0, ϵ::Real=0.001f0)
α̅(t) = cos((t / T + ϵ) / (1 + ϵ) * π / 2)^2
β = Float32[]
β = Vector{Real}(undef, T)
for t in 1:T
t1 = (t - 1) / T
t2 = t / T
αₜ = α̅(t) / α̅(t - 1)
β_t = 1 - α_bar(t2) / α_bar(t1)
β_t = min(β_max, β_t)
β= 1 - αₜ
β= min(βₘₐₓ, βₜ)
push!(β, β_t)
β[t] = βₜ
end
return β
end
"""
Rescale betas to have zero terminal SNR.
cf. [[2305.08891] Rescaling Diffusion Models](https://arxiv.org/abs/2305.08891) (Algorithm 1)
Rescale betas to have zero terminal Signal to Noise Ratio (SNR).
## Input
* β (`AbstractArray`): β_t values at each timestep t
* `β::AbstractArray`: βₜ values at each timestep t
## Output
* β (`Vector{Real}`): rescaled β_t values at each timestep t
* `β::Vector{Real}`: rescaled βₜ values at each timestep t
## References
* [[2305.08891] Rescaling Diffusion Models](https://arxiv.org/abs/2305.08891) (Alg. 1)
"""
function rescale_zero_terminal_snr(β::AbstractArray)
# convert β to sqrt_α_cumprods
# convert β to ⎷α̅
α = 1 .- β
α_cumprod = cumprod(α)
sqrt_α_cumprods = sqrt.(α_cumprod)
α̅ = cumprod(α)
⎷α̅ = sqrt.(α̅)
# store old extrema values
sqrt_α_cumprod_1 = sqrt_α_cumprods[1]
sqrt_α_cumprod_T = sqrt_α_cumprods[end]
⎷α̅₁ = ⎷α̅[1]
⎷α̅₋₁ = ⎷α̅[end]
# shift last timestep to zero
sqrt_α_cumprods .-= sqrt_α_cumprod_T
⎷α̅ .-= ⎷α̅₋₁
# scale so that first timestep reaches old values
sqrt_α_cumprods *= sqrt_α_cumprod_1 / (sqrt_α_cumprod_1 - sqrt_α_cumprod_T)
⎷α̅ *= ⎷α̅₁ / (⎷α̅₁ - ⎷α̅₋₁)
# convert back sqrt_α_cumprods to β
α_cumprod = sqrt_α_cumprods .^ 2
α = α_cumprod[2:end] ./ α_cumprod[1:end-1]
α = vcat(α_cumprod[1], α)
# convert back ⎷α̅ to β
α̅ = ⎷α̅ .^ 2
α = α̅[2:end] ./ α̅[1:end-1]
α = vcat(α̅[1], α)
β = 1 .- α
return β

View file

@ -3,39 +3,65 @@ include("Schedulers.jl")
"""
Denoising Diffusion Probabilistic Models (DDPM) scheduler.
cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
## References
* [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
"""
struct DDPM{V<:AbstractVector} <: Scheduler
# number of diffusion steps used to train the model.
T_train::Int
T::Integer # length of markov chain
# the betas used for the diffusion steps
β::V
β::V # beta variance schedule
α::V # 1 - beta
# internal variables used for computation (derived from β)
α::V
α_cumprods::V
α_cumprod_prevs::V
sqrt_α_cumprods::V
sqrt_one_minus_α_cumprods::V
⎷α::V # square root of α
⎷β::V # square root of β
α̅::V # cumulative product of α
β̅::V # 1 - α̅ (≠ cumprod(β))
α̅₋₁::V # right-shifted α̅
β̅₋₁::V # 1 - α̅₋₁
⎷α̅::V # square root of α̅
⎷β̅::V # square root of β̅
⎷α̅₋₁::V # square root of α̅₋₁
⎷β̅₋₁::V # square root of β̅₋₁
end
function DDPM(V::DataType, β::AbstractVector)
α = 1 .- β
α_cumprods = cumprod(α)
α_cumprod_prevs = [1, (α_cumprods[1:end-1])...]
T = length(β)
sqrt_α_cumprods = sqrt.(α_cumprods)
sqrt_one_minus_α_cumprods = sqrt.(1 .- α_cumprods)
α = 1 .- β
⎷α = sqrt.(α)
⎷β = sqrt.(β)
α̅ = cumprod(α)
β̅ = 1 .- α̅
α̅₋₁ = [1, (α̅[1:end-1])...]
β̅₋₁ = 1 .- α̅₋₁
⎷α̅ = sqrt.(α̅)
⎷β̅ = sqrt.(β̅)
⎷α̅₋₁ = sqrt.(α̅₋₁)
⎷β̅₋₁ = sqrt.(β̅₋₁)
DDPM{V}(
length(β),
T,
β,
α,
α_cumprods,
α_cumprod_prevs,
sqrt_α_cumprods,
sqrt_one_minus_α_cumprods,
⎷α,
⎷β,
α̅,
β̅,
α̅₋₁,
β̅₋₁,
⎷α̅,
⎷β̅,
⎷α̅₋₁,
⎷β̅₋₁,
)
end
@ -43,47 +69,47 @@ end
Remove noise from model output using the backward diffusion process.
## Input
* scheduler (`DDPM`): scheduler to use
* sample (`AbstractArray`): sample to remove noise from, i.e. model_input
* model_output (`AbstractArray`): predicted noise from the model
* timesteps (`AbstractArray`): timesteps to remove noise from
* `scheduler::DDPM`: scheduler to use
* `xₜ::AbstractArray`: sample to be denoised
* `ϵᵧ::AbstractArray`: predicted noise to remove
* `t::AbstractArray`: timestep t of `xₜ`
## Output
* pred_prev_sample (`AbstractArray`): denoised sample at t=t-1
* x_0_pred (`AbstractArray`): denoised sample at t=0
* `xₜ₋₁::AbstractArray`: denoised sample at t=t-1
* `x̂₀::AbstractArray`: denoised sample at t=0
"""
function step(
scheduler::DDPM,
sample::AbstractArray,
model_output::AbstractArray,
timesteps::AbstractArray,
xₜ::AbstractArray,
ϵᵧ::AbstractArray,
t::AbstractArray,
)
# 1. compute alphas, betas
α_cumprod_t = scheduler.α_cumprods[timesteps]
α_cumprod_t_prev = scheduler.α_cumprods[timesteps.-1]
β_cumprod_t = 1 .- α_cumprod_t
β_cumprod_t_prev = 1 .- α_cumprod_t_prev
current_α_t = α_cumprod_t ./ α_cumprod_t_prev
current_β_t = 1 .- current_α_t
# retreive scheduler variables at timesteps t
βₜ = scheduler.β[t]
β̅ₜ = scheduler.β̅[t]
β̅ₜ₋₁ = scheduler.β̅₋₁[t]
⎷αₜ = scheduler.⎷α[t]
⎷α̅ₜ = scheduler.⎷α̅[t]
⎷α̅ₜ₋₁ = scheduler.⎷α̅₋₁[t]
⎷β̅ₜ = scheduler.⎷β̅[t]
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
# epsilon prediction type
# print shapes of thingies
x_0_pred = (sample - sqrt.(β_cumprod_t)' .* model_output) ./ sqrt.(α_cumprod_t)'
# compute predicted previous sample x̂₀
# arxiv:2006.11239 Eq. 15
# arxiv:2208.11970 Eq. 115
x̂₀ = (xₜ - ⎷β̅ₜ' .* ϵᵧ) ./ ⎷α̅ₜ'
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_sample_coeff = (sqrt.(α_cumprod_t_prev) .* current_β_t) ./ β_cumprod_t
current_sample_coeff = sqrt.(current_α_t) .* β_cumprod_t_prev ./ β_cumprod_t
# compute predicted previous sample μ̃ₜ
# arxiv:2006.11239 Eq. 7
# arxiv:2208.11970 Eq. 84
λ₀ = ⎷α̅ₜ₋₁ .* βₜ ./ β̅ₜ
λₜ = ⎷αₜ .* β̅ₜ₋₁ ./ β̅ₜ # TODO: this could be stored in the scheduler
μ̃ₜ = λ₀' .* x̂₀ + λₜ' .* xₜ
# 5. Compute predicted previous sample µ_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_prev_sample = pred_original_sample_coeff' .* x_0_pred + current_sample_coeff' .* sample
# sample predicted previous sample xₜ₋₁
# arxiv:2006.11239 Eq. 6
# arxiv:2208.11970 Eq. 70
σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ # TODO: this could be stored in the scheduler
xₜ₋₁ = μ̃ₜ + σₜ' .* randn(size(ϵᵧ))
# 6. Add noise
variance = sqrt.(scheduler.β[timesteps])' .* randn(size(model_output))
pred_prev_sample = pred_prev_sample + variance
return pred_prev_sample, x_0_pred
return xₜ₋₁, x̂₀
end

View file

@ -6,5 +6,6 @@ include("BetaSchedulers.jl")
# concrete types
include("DDPM.jl")
# include("DDIM.jl")
end # module Diffusers

View file

@ -6,24 +6,24 @@ Add noise to clean data using the forward diffusion process.
cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) (Eq. 4)
## Input
* scheduler (`Scheduler`): scheduler to use
* clean_data (`AbstractArray`): clean data to add noise to
* noise (`AbstractArray`): noise to add to clean data
* timesteps (`AbstractArray`): timesteps used to weight the noise
* `scheduler::Scheduler`: scheduler to use
* `clean_data::AbstractArray`: clean data to add noise to
* `noise::AbstractArray`: noise to add to clean data
* `timesteps::AbstractArray`: timesteps used to weight the noise
## Output
* noisy_data (`AbstractArray`): noisy data at the given timesteps
* `noisy_data::AbstractArray`: noisy data at the given timesteps
"""
function add_noise(
scheduler::Scheduler,
clean_data::AbstractArray,
noise::AbstractArray,
timesteps::AbstractArray,
x₀::AbstractArray,
ϵ::AbstractArray,
t::AbstractArray,
)
sqrt_α_cumprod_t = scheduler.sqrt_α_cumprods[timesteps]
sqrt_one_minus_α_cumprod_t = scheduler.sqrt_one_minus_α_cumprods[timesteps]
⎷α̅ₜ = scheduler.⎷α̅[t]
⎷β̅ₜ = scheduler.⎷β̅[t]
noisy_data = sqrt_α_cumprod_t' .* clean_data + sqrt_one_minus_α_cumprod_t' .* noise
xₜ = ⎷α̅ₜ' .* x₀ + ⎷β̅ₜ' .* ϵ
return noisy_data
return xₜ
end