From f6ecfdbfda80dee7c718ef0f210e94deeb4e3353 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laure=CE=B7t?= Date: Sat, 29 Jul 2023 15:18:42 +0200 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20rework=20docstrings=20and?= =?UTF-8?q?=20variable=20names?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/ConditionalChain.jl | 2 +- examples/Embeddings.jl | 4 +- examples/swissroll.jl | 10 +-- src/BetaSchedulers.jl | 115 ++++++++++++++++-------------- src/DDPM.jl | 134 +++++++++++++++++++++-------------- src/Diffusers.jl | 1 + src/Schedulers.jl | 24 +++---- 7 files changed, 162 insertions(+), 128 deletions(-) diff --git a/examples/ConditionalChain.jl b/examples/ConditionalChain.jl index 8b58056..0c279ee 100644 --- a/examples/ConditionalChain.jl +++ b/examples/ConditionalChain.jl @@ -46,7 +46,7 @@ function Base.show(io::IO, c::ConditionalChain) print(io, ")") end -function _big_show(io::IO, m::ConditionalChain{T}, indent::Int=0, name=nothing) where {T<:NamedTuple} +function _big_show(io::IO, m::ConditionalChain{T}, indent::Integer=0, name=nothing) where {T<:NamedTuple} println(io, " "^indent, isnothing(name) ? 
"" : "$name = ", "ConditionalChain(") for k in Base.keys(m.layers) _big_show(io, m.layers[k], indent + 2, k) diff --git a/examples/Embeddings.jl b/examples/Embeddings.jl index 897dbe3..746b8e6 100644 --- a/examples/Embeddings.jl +++ b/examples/Embeddings.jl @@ -7,12 +7,12 @@ end Flux.@functor SinusoidalPositionEmbedding Flux.trainable(emb::SinusoidalPositionEmbedding) = () # mark it as an non-trainable array -function SinusoidalPositionEmbedding(in::Int, out::Int) +function SinusoidalPositionEmbedding(in::Integer, out::Integer) W = make_positional_embedding(out, in) SinusoidalPositionEmbedding(W) end -function make_positional_embedding(dim_embedding::Int, seq_length::Int=1000; n::Int=10000) +function make_positional_embedding(dim_embedding::Integer, seq_length::Integer=1000; n::Integer=10000) embedding = Matrix{Float32}(undef, dim_embedding, seq_length) for pos in 1:seq_length for row in 0:2:(dim_embedding-1) diff --git a/examples/swissroll.jl b/examples/swissroll.jl index fc6d11a..0e8d7ee 100644 --- a/examples/swissroll.jl +++ b/examples/swissroll.jl @@ -8,7 +8,7 @@ using ProgressMeter include("Embeddings.jl") include("ConditionalChain.jl") -function make_spiral(rng::AbstractRNG, n_samples::Int=1000) +function make_spiral(rng::AbstractRNG, n_samples::Integer=1000) t_min = 1.5π t_max = 4.5π @@ -20,7 +20,7 @@ function make_spiral(rng::AbstractRNG, n_samples::Int=1000) permutedims([x y], (2, 1)) end -make_spiral(n_samples::Int=1000) = make_spiral(Random.GLOBAL_RNG, n_samples) +make_spiral(n_samples::Integer=1000) = make_spiral(Random.GLOBAL_RNG, n_samples) function normalize_zero_to_one(x) x_min, x_max = extrema(x) @@ -42,7 +42,9 @@ scatter(data[1, :], data[2, :], num_timesteps = 100 scheduler = Diffusers.DDPM( Vector{Float64}, - Diffusers.cosine_beta_schedule(num_timesteps, 0.999f0, 0.001f0), + Diffusers.rescale_zero_terminal_snr( + Diffusers.cosine_beta_schedule(num_timesteps, 0.999f0, 0.001f0) + ) ) noise = randn(size(data)) @@ -104,7 +106,7 @@ model = 
ConditionalChain( model(data, [100]) -num_epochs = 1000 +num_epochs = 100 loss = Flux.Losses.mse opt = Flux.setup(Adam(0.0001), model) dataloader = Flux.DataLoader(data |> cpu; batchsize=32, shuffle=true); diff --git a/src/BetaSchedulers.jl b/src/BetaSchedulers.jl index 7c8a3a0..e47a150 100644 --- a/src/BetaSchedulers.jl +++ b/src/BetaSchedulers.jl @@ -3,117 +3,122 @@ import NNlib: sigmoid """ Linear beta schedule. -cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) - ## Input - * T (`Int`): number of timesteps - * β_1 (`Real := 0.0001f0`): initial value of β - * β_T (`Real := 0.02f0`): final value of β + * `T::Integer`: number of timesteps + * `β₁::Real=0.0001f0`: initial (t=1) value of β + * `β₋₁::Real=0.02f0`: final (t=T) value of β ## Output - * β (`Vector{Real}`): β_t values at each timestep t + * `β::Vector{Real}`: βₜ values at each timestep t + +## References + * [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) """ -function linear_beta_schedule(T::Int, β_1::Real=0.0001f0, β_T::Real=0.02f0) - return range(start=β_1, stop=β_T, length=T) +function linear_beta_schedule(T::Integer, β₁::Real=0.0001f0, β₋₁::Real=0.02f0) + return range(start=β₁, stop=β₋₁, length=T) end """ Scaled linear beta schedule. -cf. 
[[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) - ## Input - * T (`Int`): number of timesteps - * β_1 (`Real := 0.0001f0`): initial value of β - * β_T (`Real := 0.02f0`): final value of β + * `T::Integer`: number of timesteps + * `β₁::Real=0.0001f0`: initial value of β + * `β₋₁::Real=0.02f0`: final value of β ## Output - * β (`Vector{Real}`): β_t values at each timestep t + * `β::Vector{Real}`: βₜ values at each timestep t + +## References + * [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) """ -function scaled_linear_beta_schedule(T::Int, β_1::Real=0.0001f0, β_T::Real=0.02f0) - return range(start=β_1^0.5, stop=β_T^0.5, length=T) .^ 2 +function scaled_linear_beta_schedule(T::Integer, β₁::Real=0.0001f0, β₋₁::Real=0.02f0) + return range(start=β₁^0.5, stop=β₋₁^0.5, length=T) .^ 2 end """ Sigmoid beta schedule. -cf. [[2203.02923] GeoDiff: a Geometric Diffusion Model for Molecular Conformation Generation](https://arxiv.org/abs/2203.02923) -and [github.com:MinkaiXu/GeoDiff](https://github.com/MinkaiXu/GeoDiff/blob/ea0ca48045a2f7abfccd7f0df449e45eb6eae638/models/epsnet/diffusion.py#L57) - ## Input - * T (`Int`): number of timesteps - * β_1 (`Real := 0.0001f0`): initial value of β - * β_T (`Real := 0.02f0`): final value of β + * `T::Integer`: number of timesteps + * `β₁::Real=0.0001f0`: initial value of β + * `β₋₁::Real=0.02f0`: final value of β ## Output - * β (`Vector{Real}`): β_t values at each timestep t + * `β::Vector{Real}`: βₜ values at each timestep t + +## References + * [[2203.02923] GeoDiff: a Geometric Diffusion Model for Molecular Conformation Generation](https://arxiv.org/abs/2203.02923) + * [github.com:MinkaiXu/GeoDiff](https://github.com/MinkaiXu/GeoDiff/blob/ea0ca48045a2f7abfccd7f0df449e45eb6eae638/models/epsnet/diffusion.py#L57) """ -function sigmoid_beta_schedule(T::Int, β_1::Real=0.0001f0, 
β₋₁::Real=0.02f0) x = range(start=-6, stop=6, length=T) - return sigmoid(x) .* (β_T - β_1) .+ β_1 + return sigmoid(x) .* (β₋₁ - β₁) .+ β₁ end """ Cosine beta schedule. -cf. [[2102.09672] Improved Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2102.09672) - ## Input - * T (`Int`): number of timesteps - * β_max (`Real := 0.999f0`): maximum value of β - * ϵ (`Real := 1e-3f0`): small value used to avoid division by zero + * `T::Integer`: number of timesteps + * `βₘₐₓ::Real=0.999f0`: maximum value of β + * `ϵ::Real=0.001f0`: small value used to avoid division by zero ## Output - * β (`Vector{Real}`): β_t values at each timestep t + * `β::Vector{Real}`: βₜ values at each timestep t + +## References + * [[2102.09672] Improved Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2102.09672) + * [github:openai/improved-diffusion](https://github.com/openai/improved-diffusion/blob/783b6740edb79fdb7d063250db2c51cc9545dcd1/improved_diffusion/gaussian_diffusion.py#L36) """ -function cosine_beta_schedule(T::Int, β_max::Real=0.999f0, ϵ::Real=0.001f0) - α_bar(t) = cos((t + ϵ) / (1 + ϵ) * π / 2)^2 +function cosine_beta_schedule(T::Integer, βₘₐₓ::Real=0.999f0, ϵ::Real=0.001f0) + α̅(t) = cos((t / T + ϵ) / (1 + ϵ) * π / 2)^2 - β = Float32[] + β = Vector{Real}(undef, T) for t in 1:T - t1 = (t - 1) / T - t2 = t / T + αₜ = α̅(t) / α̅(t - 1) - β_t = 1 - α_bar(t2) / α_bar(t1) - β_t = min(β_max, β_t) + βₜ = 1 - αₜ + βₜ = min(βₘₐₓ, βₜ) - push!(β, β_t) + β[t] = βₜ end return β end """ -Rescale betas to have zero terminal SNR. - -cf. [[2305.08891] Rescaling Diffusion Models](https://arxiv.org/abs/2305.08891) (Algorithm 1) +Rescale betas to have zero terminal Signal to Noise Ratio (SNR). 
## Input - * β (`AbstractArray`): β_t values at each timestep t + * `β::AbstractArray`: βₜ values at each timestep t ## Output - * β (`Vector{Real}`): rescaled β_t values at each timestep t + * `β::Vector{Real}`: rescaled βₜ values at each timestep t + +## References + * [[2305.08891] Rescaling Diffusion Models](https://arxiv.org/abs/2305.08891) (Alg. 1) """ function rescale_zero_terminal_snr(β::AbstractArray) - # convert β to sqrt_α_cumprods + # convert β to ⎷α̅ α = 1 .- β - α_cumprod = cumprod(α) - sqrt_α_cumprods = sqrt.(α_cumprod) + α̅ = cumprod(α) + ⎷α̅ = sqrt.(α̅) # store old extrema values - sqrt_α_cumprod_1 = sqrt_α_cumprods[1] - sqrt_α_cumprod_T = sqrt_α_cumprods[end] + ⎷α̅₁ = ⎷α̅[1] + ⎷α̅₋₁ = ⎷α̅[end] # shift last timestep to zero - sqrt_α_cumprods .-= sqrt_α_cumprod_T + ⎷α̅ .-= ⎷α̅₋₁ # scale so that first timestep reaches old values - sqrt_α_cumprods *= sqrt_α_cumprod_1 / (sqrt_α_cumprod_1 - sqrt_α_cumprod_T) + ⎷α̅ *= ⎷α̅₁ / (⎷α̅₁ - ⎷α̅₋₁) - # convert back sqrt_α_cumprods to β - α_cumprod = sqrt_α_cumprods .^ 2 - α = α_cumprod[2:end] ./ α_cumprod[1:end-1] - α = vcat(α_cumprod[1], α) + # convert back ⎷α̅ to β + α̅ = ⎷α̅ .^ 2 + α = α̅[2:end] ./ α̅[1:end-1] + α = vcat(α̅[1], α) β = 1 .- α return β diff --git a/src/DDPM.jl b/src/DDPM.jl index be491cf..d8aec81 100644 --- a/src/DDPM.jl +++ b/src/DDPM.jl @@ -3,39 +3,65 @@ include("Schedulers.jl") """ Denoising Diffusion Probabilistic Models (DDPM) scheduler. -cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) +## References + * [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) """ struct DDPM{V<:AbstractVector} <: Scheduler - # number of diffusion steps used to train the model. 
- T_train::Int + T::Integer # length of markov chain - # the betas used for the diffusion steps - β::V + β::V # beta variance schedule + α::V # 1 - beta - # internal variables used for computation (derived from β) - α::V - α_cumprods::V - α_cumprod_prevs::V - sqrt_α_cumprods::V - sqrt_one_minus_α_cumprods::V + ⎷α::V # square root of α + ⎷β::V # square root of β + + α̅::V # cumulative product of α + β̅::V # 1 - α̅ (≠ cumprod(β)) + + α̅₋₁::V # right-shifted α̅ + β̅₋₁::V # 1 - α̅₋₁ + + ⎷α̅::V # square root of α̅ + ⎷β̅::V # square root of β̅ + + ⎷α̅₋₁::V # square root of α̅₋₁ + ⎷β̅₋₁::V # square root of β̅₋₁ end function DDPM(V::DataType, β::AbstractVector) - α = 1 .- β - α_cumprods = cumprod(α) - α_cumprod_prevs = [1, (α_cumprods[1:end-1])...] + T = length(β) - sqrt_α_cumprods = sqrt.(α_cumprods) - sqrt_one_minus_α_cumprods = sqrt.(1 .- α_cumprods) + α = 1 .- β + + ⎷α = sqrt.(α) + ⎷β = sqrt.(β) + + α̅ = cumprod(α) + β̅ = 1 .- α̅ + + α̅₋₁ = [1, (α̅[1:end-1])...] + β̅₋₁ = 1 .- α̅₋₁ + + ⎷α̅ = sqrt.(α̅) + ⎷β̅ = sqrt.(β̅) + + ⎷α̅₋₁ = sqrt.(α̅₋₁) + ⎷β̅₋₁ = sqrt.(β̅₋₁) DDPM{V}( - length(β), + T, β, α, - α_cumprods, - α_cumprod_prevs, - sqrt_α_cumprods, - sqrt_one_minus_α_cumprods, + ⎷α, + ⎷β, + α̅, + β̅, + α̅₋₁, + β̅₋₁, + ⎷α̅, + ⎷β̅, + ⎷α̅₋₁, + ⎷β̅₋₁, ) end @@ -43,47 +69,47 @@ end Remove noise from model output using the backward diffusion process. ## Input - * scheduler (`DDPM`): scheduler to use - * sample (`AbstractArray`): sample to remove noise from, i.e. 
model_input - * model_output (`AbstractArray`): predicted noise from the model - * timesteps (`AbstractArray`): timesteps to remove noise from + * `scheduler::DDPM`: scheduler to use + * `xₜ::AbstractArray`: sample to be denoised + * `ϵᵧ::AbstractArray`: predicted noise to remove + * `t::AbstractArray`: timestep t of `xₜ` ## Output - * pred_prev_sample (`AbstractArray`): denoised sample at t=t-1 - * x_0_pred (`AbstractArray`): denoised sample at t=0 + * `xₜ₋₁::AbstractArray`: denoised sample at t=t-1 + * `x̂₀::AbstractArray`: denoised sample at t=0 """ function step( scheduler::DDPM, - sample::AbstractArray, - model_output::AbstractArray, - timesteps::AbstractArray, + xₜ::AbstractArray, + ϵᵧ::AbstractArray, + t::AbstractArray, ) - # 1. compute alphas, betas - α_cumprod_t = scheduler.α_cumprods[timesteps] - α_cumprod_t_prev = scheduler.α_cumprods[timesteps.-1] - β_cumprod_t = 1 .- α_cumprod_t - β_cumprod_t_prev = 1 .- α_cumprod_t_prev - current_α_t = α_cumprod_t ./ α_cumprod_t_prev - current_β_t = 1 .- current_α_t + # retrieve scheduler variables at timesteps t + βₜ = scheduler.β[t] + β̅ₜ = scheduler.β̅[t] + β̅ₜ₋₁ = scheduler.β̅₋₁[t] + ⎷αₜ = scheduler.⎷α[t] + ⎷α̅ₜ = scheduler.⎷α̅[t] + ⎷α̅ₜ₋₁ = scheduler.⎷α̅₋₁[t] + ⎷β̅ₜ = scheduler.⎷β̅[t] - # 2. compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - # epsilon prediction type - # print shapes of thingies - x_0_pred = (sample - sqrt.(β_cumprod_t)' .* model_output) ./ sqrt.(α_cumprod_t)' + # compute predicted original sample x̂₀ + # arxiv:2006.11239 Eq. 15 + # arxiv:2208.11970 Eq. 115 + x̂₀ = (xₜ - ⎷β̅ₜ' .* ϵᵧ) ./ ⎷α̅ₜ' - # 4. 
Compute coefficients for pred_original_sample x_0 and current sample x_t - # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - pred_original_sample_coeff = (sqrt.(α_cumprod_t_prev) .* current_β_t) ./ β_cumprod_t - current_sample_coeff = sqrt.(current_α_t) .* β_cumprod_t_prev ./ β_cumprod_t + # compute posterior mean μ̃ₜ of the previous sample + # arxiv:2006.11239 Eq. 7 + # arxiv:2208.11970 Eq. 84 + λ₀ = ⎷α̅ₜ₋₁ .* βₜ ./ β̅ₜ + λₜ = ⎷αₜ .* β̅ₜ₋₁ ./ β̅ₜ # TODO: this could be stored in the scheduler + μ̃ₜ = λ₀' .* x̂₀ + λₜ' .* xₜ - # 5. Compute predicted previous sample µ_t - # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - pred_prev_sample = pred_original_sample_coeff' .* x_0_pred + current_sample_coeff' .* sample + # sample predicted previous sample xₜ₋₁ + # arxiv:2006.11239 Eq. 6 + # arxiv:2208.11970 Eq. 70 + σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ # TODO: this could be stored in the scheduler; NOTE(review): this expression is the posterior variance β̃ₜ of Eq. 7, so σₜ presumably needs a sqrt before scaling the noise (the removed code used sqrt.(β)) - confirm + xₜ₋₁ = μ̃ₜ + σₜ' .* randn(size(ϵᵧ)) - # 6. Add noise - variance = sqrt.(scheduler.β[timesteps])' .* randn(size(model_output)) - pred_prev_sample = pred_prev_sample + variance - - return pred_prev_sample, x_0_pred + return xₜ₋₁, x̂₀ end diff --git a/src/Diffusers.jl b/src/Diffusers.jl index c15a5f1..785e4fd 100644 --- a/src/Diffusers.jl +++ b/src/Diffusers.jl @@ -6,5 +6,6 @@ include("BetaSchedulers.jl") # concrete types include("DDPM.jl") +# include("DDIM.jl") end # module Diffusers diff --git a/src/Schedulers.jl b/src/Schedulers.jl index 707d21f..b78539b 100644 --- a/src/Schedulers.jl +++ b/src/Schedulers.jl @@ -6,24 +6,24 @@ Add noise to clean data using the forward diffusion process. cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) (Eq. 
4) ## Input - * scheduler (`Scheduler`): scheduler to use - * clean_data (`AbstractArray`): clean data to add noise to - * noise (`AbstractArray`): noise to add to clean data - * timesteps (`AbstractArray`): timesteps used to weight the noise + * `scheduler::Scheduler`: scheduler to use + * `x₀::AbstractArray`: clean data to add noise to + * `ϵ::AbstractArray`: noise to add to clean data + * `t::AbstractArray`: timesteps used to weight the noise ## Output - * noisy_data (`AbstractArray`): noisy data at the given timesteps + * `xₜ::AbstractArray`: noisy data at the given timesteps """ function add_noise( scheduler::Scheduler, - clean_data::AbstractArray, - noise::AbstractArray, - timesteps::AbstractArray, + x₀::AbstractArray, + ϵ::AbstractArray, + t::AbstractArray, ) - sqrt_α_cumprod_t = scheduler.sqrt_α_cumprods[timesteps] - sqrt_one_minus_α_cumprod_t = scheduler.sqrt_one_minus_α_cumprods[timesteps] + ⎷α̅ₜ = scheduler.⎷α̅[t] + ⎷β̅ₜ = scheduler.⎷β̅[t] - noisy_data = sqrt_α_cumprod_t' .* clean_data + sqrt_one_minus_α_cumprod_t' .* noise + xₜ = ⎷α̅ₜ' .* x₀ + ⎷β̅ₜ' .* ϵ - return noisy_data + return xₜ end