♻️ rework docstrings and variable names

This commit is contained in:
Laureηt 2023-07-29 15:18:42 +02:00
parent 43695ffcf7
commit f6ecfdbfda
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
7 changed files with 162 additions and 128 deletions

View file

@ -46,7 +46,7 @@ function Base.show(io::IO, c::ConditionalChain)
print(io, ")") print(io, ")")
end end
function _big_show(io::IO, m::ConditionalChain{T}, indent::Int=0, name=nothing) where {T<:NamedTuple} function _big_show(io::IO, m::ConditionalChain{T}, indent::Integer=0, name=nothing) where {T<:NamedTuple}
println(io, " "^indent, isnothing(name) ? "" : "$name = ", "ConditionalChain(") println(io, " "^indent, isnothing(name) ? "" : "$name = ", "ConditionalChain(")
for k in Base.keys(m.layers) for k in Base.keys(m.layers)
_big_show(io, m.layers[k], indent + 2, k) _big_show(io, m.layers[k], indent + 2, k)

View file

@ -7,12 +7,12 @@ end
Flux.@functor SinusoidalPositionEmbedding Flux.@functor SinusoidalPositionEmbedding
Flux.trainable(emb::SinusoidalPositionEmbedding) = () # mark it as a non-trainable array Flux.trainable(emb::SinusoidalPositionEmbedding) = () # mark it as a non-trainable array
function SinusoidalPositionEmbedding(in::Int, out::Int) function SinusoidalPositionEmbedding(in::Integer, out::Integer)
W = make_positional_embedding(out, in) W = make_positional_embedding(out, in)
SinusoidalPositionEmbedding(W) SinusoidalPositionEmbedding(W)
end end
function make_positional_embedding(dim_embedding::Int, seq_length::Int=1000; n::Int=10000) function make_positional_embedding(dim_embedding::Integer, seq_length::Integer=1000; n::Integer=10000)
embedding = Matrix{Float32}(undef, dim_embedding, seq_length) embedding = Matrix{Float32}(undef, dim_embedding, seq_length)
for pos in 1:seq_length for pos in 1:seq_length
for row in 0:2:(dim_embedding-1) for row in 0:2:(dim_embedding-1)

View file

@ -8,7 +8,7 @@ using ProgressMeter
include("Embeddings.jl") include("Embeddings.jl")
include("ConditionalChain.jl") include("ConditionalChain.jl")
function make_spiral(rng::AbstractRNG, n_samples::Int=1000) function make_spiral(rng::AbstractRNG, n_samples::Integer=1000)
t_min = 1.5π t_min = 1.5π
t_max = 4.5π t_max = 4.5π
@ -20,7 +20,7 @@ function make_spiral(rng::AbstractRNG, n_samples::Int=1000)
permutedims([x y], (2, 1)) permutedims([x y], (2, 1))
end end
make_spiral(n_samples::Int=1000) = make_spiral(Random.GLOBAL_RNG, n_samples) make_spiral(n_samples::Integer=1000) = make_spiral(Random.GLOBAL_RNG, n_samples)
function normalize_zero_to_one(x) function normalize_zero_to_one(x)
x_min, x_max = extrema(x) x_min, x_max = extrema(x)
@ -42,7 +42,9 @@ scatter(data[1, :], data[2, :],
num_timesteps = 100 num_timesteps = 100
scheduler = Diffusers.DDPM( scheduler = Diffusers.DDPM(
Vector{Float64}, Vector{Float64},
Diffusers.cosine_beta_schedule(num_timesteps, 0.999f0, 0.001f0), Diffusers.rescale_zero_terminal_snr(
Diffusers.cosine_beta_schedule(num_timesteps, 0.999f0, 0.001f0)
)
) )
noise = randn(size(data)) noise = randn(size(data))
@ -104,7 +106,7 @@ model = ConditionalChain(
model(data, [100]) model(data, [100])
num_epochs = 1000 num_epochs = 100
loss = Flux.Losses.mse loss = Flux.Losses.mse
opt = Flux.setup(Adam(0.0001), model) opt = Flux.setup(Adam(0.0001), model)
dataloader = Flux.DataLoader(data |> cpu; batchsize=32, shuffle=true); dataloader = Flux.DataLoader(data |> cpu; batchsize=32, shuffle=true);

View file

@ -3,117 +3,122 @@ import NNlib: sigmoid
""" """
Linear beta schedule. Linear beta schedule.
cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
## Input ## Input
* T (`Int`): number of timesteps * `T::Integer`: number of timesteps
* β_1 (`Real := 0.0001f0`): initial value of β * `β₁::Real=0.0001f0`: initial (t=1) value of β
* β_T (`Real := 0.02f0`): final value of β * `β₋₁::Real=0.02f0`: final (t=T) value of β
## Output ## Output
* β (`Vector{Real}`): β_t values at each timestep t * `β::Vector{Real}`: βₜ values at each timestep t
## References
* [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
""" """
function linear_beta_schedule(T::Int, β_1::Real=0.0001f0, β_T::Real=0.02f0) function linear_beta_schedule(T::Integer, β₁::Real=0.0001f0, β₋₁::Real=0.02f0)
return range(start=β_1, stop=β_T, length=T) return range(start=β₁, stop=β₋₁, length=T)
end end
""" """
Scaled linear beta schedule. Scaled linear beta schedule.
cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
## Input ## Input
* T (`Int`): number of timesteps * `T::Integer`: number of timesteps
* β_1 (`Real := 0.0001f0`): initial value of β * `β₁::Real=0.0001f0`: initial value of β
* β_T (`Real := 0.02f0`): final value of β * `β₋₁::Real=0.02f0`: final value of β
## Output ## Output
* β (`Vector{Real}`): β_t values at each timestep t * `β::Vector{Real}`: βₜ values at each timestep t
## References
* [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
""" """
function scaled_linear_beta_schedule(T::Int, β_1::Real=0.0001f0, β_T::Real=0.02f0) function scaled_linear_beta_schedule(T::Integer, β₁::Real=0.0001f0, β₋₁::Real=0.02f0)
return range(start=β_1^0.5, stop=β_T^0.5, length=T) .^ 2 return range(start=β₁^0.5, stop=β₋₁^0.5, length=T) .^ 2
end end
""" """
Sigmoid beta schedule. Sigmoid beta schedule.
cf. [[2203.02923] GeoDiff: a Geometric Diffusion Model for Molecular Conformation Generation](https://arxiv.org/abs/2203.02923)
and [github.com:MinkaiXu/GeoDiff](https://github.com/MinkaiXu/GeoDiff/blob/ea0ca48045a2f7abfccd7f0df449e45eb6eae638/models/epsnet/diffusion.py#L57)
## Input ## Input
* T (`Int`): number of timesteps * `T::Integer`: number of timesteps
* β_1 (`Real := 0.0001f0`): initial value of β * `β₁::Real=0.0001f0`: initial value of β
* β_T (`Real := 0.02f0`): final value of β * `β₋₁::Real=0.02f0`: final value of β
## Output ## Output
* β (`Vector{Real}`): β_t values at each timestep t * `β::Vector{Real}`: βₜ values at each timestep t
## References
* [[2203.02923] GeoDiff: a Geometric Diffusion Model for Molecular Conformation Generation](https://arxiv.org/abs/2203.02923)
* [github.com:MinkaiXu/GeoDiff](https://github.com/MinkaiXu/GeoDiff/blob/ea0ca48045a2f7abfccd7f0df449e45eb6eae638/models/epsnet/diffusion.py#L57)
""" """
function sigmoid_beta_schedule(T::Int, β_1::Real=0.0001f0, β_T::Real=0.02f0) function sigmoid_beta_schedule(T::Integer, β₁::Real=0.0001f0, β₋₁::Real=0.02f0)
x = range(start=-6, stop=6, length=T) x = range(start=-6, stop=6, length=T)
return sigmoid(x) .* (β_T - β_1) .+ β_1 return sigmoid(x) .* (β₋₁ - β₁) .+ β₁
end end
""" """
Cosine beta schedule. Cosine beta schedule.
cf. [[2102.09672] Improved Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2102.09672)
## Input ## Input
* T (`Int`): number of timesteps * `T::Integer`: number of timesteps
* β_max (`Real := 0.999f0`): maximum value of β * `βₘₐₓ::Real=0.999f0`: maximum value of β
* ϵ (`Real := 1e-3f0`): small value used to avoid division by zero * `ϵ::Real=1e-3f0`: small value used to avoid division by zero
## Output ## Output
* β (`Vector{Real}`): β_t values at each timestep t * `β::Vector{Real}`: βₜ values at each timestep t
## References
* [[2102.09672] Improved Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2102.09672)
* [github:openai/improved-diffusion](https://github.com/openai/improved-diffusion/blob/783b6740edb79fdb7d063250db2c51cc9545dcd1/improved_diffusion/gaussian_diffusion.py#L36)
""" """
function cosine_beta_schedule(T::Int, β_max::Real=0.999f0, ϵ::Real=0.001f0) function cosine_beta_schedule(T::Integer, βₘₐₓ::Real=0.999f0, ϵ::Real=0.001f0)
α_bar(t) = cos((t + ϵ) / (1 + ϵ) * π / 2)^2 α̅(t) = cos((t / T + ϵ) / (1 + ϵ) * π / 2)^2
β = Float32[] β = Vector{Real}(undef, T)
for t in 1:T for t in 1:T
t1 = (t - 1) / T αₜ = α̅(t) / α̅(t - 1)
t2 = t / T
β_t = 1 - α_bar(t2) / α_bar(t1) βₜ = 1 - αₜ
β_t = min(β_max, β_t) βₜ = min(βₘₐₓ, βₜ)
push!(β, β_t) β[t] = βₜ
end end
return β return β
end end
""" """
Rescale betas to have zero terminal SNR. Rescale betas to have zero terminal Signal to Noise Ratio (SNR).
cf. [[2305.08891] Rescaling Diffusion Models](https://arxiv.org/abs/2305.08891) (Algorithm 1)
## Input ## Input
* β (`AbstractArray`): β_t values at each timestep t * `β::AbstractArray`: βₜ values at each timestep t
## Output ## Output
* β (`Vector{Real}`): rescaled β_t values at each timestep t * `β::Vector{Real}`: rescaled βₜ values at each timestep t
## References
* [[2305.08891] Rescaling Diffusion Models](https://arxiv.org/abs/2305.08891) (Alg. 1)
""" """
function rescale_zero_terminal_snr(β::AbstractArray) function rescale_zero_terminal_snr(β::AbstractArray)
# convert β to sqrt_α_cumprods # convert β to ⎷α̅
α = 1 .- β α = 1 .- β
α_cumprod = cumprod(α) α̅ = cumprod(α)
sqrt_α_cumprods = sqrt.(α_cumprod) ⎷α̅ = sqrt.(α̅)
# store old extrema values # store old extrema values
sqrt_α_cumprod_1 = sqrt_α_cumprods[1] ⎷α̅₁ = ⎷α̅[1]
sqrt_α_cumprod_T = sqrt_α_cumprods[end] ⎷α̅₋₁ = ⎷α̅[end]
# shift last timestep to zero # shift last timestep to zero
sqrt_α_cumprods .-= sqrt_α_cumprod_T ⎷α̅ .-= ⎷α̅₋₁
# scale so that first timestep reaches old values # scale so that first timestep reaches old values
sqrt_α_cumprods *= sqrt_α_cumprod_1 / (sqrt_α_cumprod_1 - sqrt_α_cumprod_T) ⎷α̅ *= ⎷α̅₁ / (⎷α̅₁ - ⎷α̅₋₁)
# convert back sqrt_α_cumprods to β # convert back ⎷α̅ to β
α_cumprod = sqrt_α_cumprods .^ 2 α̅ = ⎷α̅ .^ 2
α = α_cumprod[2:end] ./ α_cumprod[1:end-1] α = α̅[2:end] ./ α̅[1:end-1]
α = vcat(α_cumprod[1], α) α = vcat(α̅[1], α)
β = 1 .- α β = 1 .- α
return β return β

View file

@ -3,39 +3,65 @@ include("Schedulers.jl")
""" """
Denoising Diffusion Probabilistic Models (DDPM) scheduler. Denoising Diffusion Probabilistic Models (DDPM) scheduler.
cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) ## References
* [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
""" """
struct DDPM{V<:AbstractVector} <: Scheduler struct DDPM{V<:AbstractVector} <: Scheduler
# number of diffusion steps used to train the model. T::Integer # length of markov chain
T_train::Int
# the betas used for the diffusion steps β::V # beta variance schedule
β::V α::V # 1 - beta
# internal variables used for computation (derived from β) ⎷α::V # square root of α
α::V ⎷β::V # square root of β
α_cumprods::V
α_cumprod_prevs::V α̅::V # cumulative product of α
sqrt_α_cumprods::V β̅::V # 1 - α̅ (≠ cumprod(β))
sqrt_one_minus_α_cumprods::V
α̅₋₁::V # right-shifted α̅
β̅₋₁::V # 1 - α̅₋₁
⎷α̅::V # square root of α̅
⎷β̅::V # square root of β̅
⎷α̅₋₁::V # square root of α̅₋₁
⎷β̅₋₁::V # square root of β̅₋₁
end end
function DDPM(V::DataType, β::AbstractVector) function DDPM(V::DataType, β::AbstractVector)
α = 1 .- β T = length(β)
α_cumprods = cumprod(α)
α_cumprod_prevs = [1, (α_cumprods[1:end-1])...]
sqrt_α_cumprods = sqrt.(α_cumprods) α = 1 .- β
sqrt_one_minus_α_cumprods = sqrt.(1 .- α_cumprods)
⎷α = sqrt.(α)
⎷β = sqrt.(β)
α̅ = cumprod(α)
β̅ = 1 .- α̅
α̅₋₁ = [1, (α̅[1:end-1])...]
β̅₋₁ = 1 .- α̅₋₁
⎷α̅ = sqrt.(α̅)
⎷β̅ = sqrt.(β̅)
⎷α̅₋₁ = sqrt.(α̅₋₁)
⎷β̅₋₁ = sqrt.(β̅₋₁)
DDPM{V}( DDPM{V}(
length(β), T,
β, β,
α, α,
α_cumprods, ⎷α,
α_cumprod_prevs, ⎷β,
sqrt_α_cumprods, α̅,
sqrt_one_minus_α_cumprods, β̅,
α̅₋₁,
β̅₋₁,
⎷α̅,
⎷β̅,
⎷α̅₋₁,
⎷β̅₋₁,
) )
end end
@ -43,47 +69,47 @@ end
Remove noise from model output using the backward diffusion process. Remove noise from model output using the backward diffusion process.
## Input ## Input
* scheduler (`DDPM`): scheduler to use * `scheduler::DDPM`: scheduler to use
* sample (`AbstractArray`): sample to remove noise from, i.e. model_input * `xₜ::AbstractArray`: sample to be denoised
* model_output (`AbstractArray`): predicted noise from the model * `ϵᵧ::AbstractArray`: predicted noise to remove
* timesteps (`AbstractArray`): timesteps to remove noise from * `t::AbstractArray`: timestep t of `xₜ`
## Output ## Output
* pred_prev_sample (`AbstractArray`): denoised sample at t=t-1 * `xₜ₋₁::AbstractArray`: denoised sample at t=t-1
* x_0_pred (`AbstractArray`): denoised sample at t=0 * `x̂₀::AbstractArray`: denoised sample at t=0
""" """
function step( function step(
scheduler::DDPM, scheduler::DDPM,
sample::AbstractArray, xₜ::AbstractArray,
model_output::AbstractArray, ϵᵧ::AbstractArray,
timesteps::AbstractArray, t::AbstractArray,
) )
# 1. compute alphas, betas # retrieve scheduler variables at timesteps t
α_cumprod_t = scheduler.α_cumprods[timesteps] βₜ = scheduler.β[t]
α_cumprod_t_prev = scheduler.α_cumprods[timesteps.-1] β̅ₜ = scheduler.β̅[t]
β_cumprod_t = 1 .- α_cumprod_t β̅ₜ₋₁ = scheduler.β̅₋₁[t]
β_cumprod_t_prev = 1 .- α_cumprod_t_prev ⎷αₜ = scheduler.⎷α[t]
current_α_t = α_cumprod_t ./ α_cumprod_t_prev ⎷α̅ₜ = scheduler.⎷α̅[t]
current_β_t = 1 .- current_α_t ⎷α̅ₜ₋₁ = scheduler.⎷α̅₋₁[t]
⎷β̅ₜ = scheduler.⎷β̅[t]
# 2. compute predicted original sample from predicted noise also called # compute predicted original sample x̂₀
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf # arxiv:2006.11239 Eq. 15
# epsilon prediction type # arxiv:2208.11970 Eq. 115
# print shapes of thingies x̂₀ = (xₜ - ⎷β̅ₜ' .* ϵᵧ) ./ ⎷α̅ₜ'
x_0_pred = (sample - sqrt.(β_cumprod_t)' .* model_output) ./ sqrt.(α_cumprod_t)'
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t # compute predicted previous sample μ̃ₜ
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf # arxiv:2006.11239 Eq. 7
pred_original_sample_coeff = (sqrt.(α_cumprod_t_prev) .* current_β_t) ./ β_cumprod_t # arxiv:2208.11970 Eq. 84
current_sample_coeff = sqrt.(current_α_t) .* β_cumprod_t_prev ./ β_cumprod_t λ₀ = ⎷α̅ₜ₋₁ .* βₜ ./ β̅ₜ
λₜ = ⎷αₜ .* β̅ₜ₋₁ ./ β̅ₜ # TODO: this could be stored in the scheduler
μ̃ₜ = λ₀' .* x̂₀ + λₜ' .* xₜ
# 5. Compute predicted previous sample µ_t # sample predicted previous sample xₜ₋₁
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf # arxiv:2006.11239 Eq. 6
pred_prev_sample = pred_original_sample_coeff' .* x_0_pred + current_sample_coeff' .* sample # arxiv:2208.11970 Eq. 70
σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ # TODO: this could be stored in the scheduler
xₜ₋₁ = μ̃ₜ + σₜ' .* randn(size(ϵᵧ))
# 6. Add noise return xₜ₋₁, x̂₀
variance = sqrt.(scheduler.β[timesteps])' .* randn(size(model_output))
pred_prev_sample = pred_prev_sample + variance
return pred_prev_sample, x_0_pred
end end

View file

@ -6,5 +6,6 @@ include("BetaSchedulers.jl")
# concrete types # concrete types
include("DDPM.jl") include("DDPM.jl")
# include("DDIM.jl")
end # module Diffusers end # module Diffusers

View file

@ -6,24 +6,24 @@ Add noise to clean data using the forward diffusion process.
cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) (Eq. 4) cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) (Eq. 4)
## Input ## Input
* scheduler (`Scheduler`): scheduler to use * `scheduler::Scheduler`: scheduler to use
* clean_data (`AbstractArray`): clean data to add noise to * `x₀::AbstractArray`: clean data to add noise to
* noise (`AbstractArray`): noise to add to clean data * `ϵ::AbstractArray`: noise to add to clean data
* timesteps (`AbstractArray`): timesteps used to weight the noise * `t::AbstractArray`: timesteps used to weight the noise
## Output ## Output
* noisy_data (`AbstractArray`): noisy data at the given timesteps * `xₜ::AbstractArray`: noisy data at the given timesteps
""" """
function add_noise( function add_noise(
scheduler::Scheduler, scheduler::Scheduler,
clean_data::AbstractArray, x₀::AbstractArray,
noise::AbstractArray, ϵ::AbstractArray,
timesteps::AbstractArray, t::AbstractArray,
) )
sqrt_α_cumprod_t = scheduler.sqrt_α_cumprods[timesteps] ⎷α̅ₜ = scheduler.⎷α̅[t]
sqrt_one_minus_α_cumprod_t = scheduler.sqrt_one_minus_α_cumprods[timesteps] ⎷β̅ₜ = scheduler.⎷β̅[t]
noisy_data = sqrt_α_cumprod_t' .* clean_data + sqrt_one_minus_α_cumprod_t' .* noise xₜ = ⎷α̅ₜ' .* x₀ + ⎷β̅ₜ' .* ϵ
return noisy_data return xₜ
end end