mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-11-08 14:38:58 +00:00
♻️ rework docstrings and variable names
This commit is contained in:
parent
43695ffcf7
commit
f6ecfdbfda
|
@ -46,7 +46,7 @@ function Base.show(io::IO, c::ConditionalChain)
|
||||||
print(io, ")")
|
print(io, ")")
|
||||||
end
|
end
|
||||||
|
|
||||||
function _big_show(io::IO, m::ConditionalChain{T}, indent::Int=0, name=nothing) where {T<:NamedTuple}
|
function _big_show(io::IO, m::ConditionalChain{T}, indent::Integer=0, name=nothing) where {T<:NamedTuple}
|
||||||
println(io, " "^indent, isnothing(name) ? "" : "$name = ", "ConditionalChain(")
|
println(io, " "^indent, isnothing(name) ? "" : "$name = ", "ConditionalChain(")
|
||||||
for k in Base.keys(m.layers)
|
for k in Base.keys(m.layers)
|
||||||
_big_show(io, m.layers[k], indent + 2, k)
|
_big_show(io, m.layers[k], indent + 2, k)
|
||||||
|
|
|
@ -7,12 +7,12 @@ end
|
||||||
Flux.@functor SinusoidalPositionEmbedding
|
Flux.@functor SinusoidalPositionEmbedding
|
||||||
Flux.trainable(emb::SinusoidalPositionEmbedding) = () # mark it as a non-trainable array
|
Flux.trainable(emb::SinusoidalPositionEmbedding) = () # mark it as a non-trainable array
|
||||||
|
|
||||||
function SinusoidalPositionEmbedding(in::Int, out::Int)
|
function SinusoidalPositionEmbedding(in::Integer, out::Integer)
|
||||||
W = make_positional_embedding(out, in)
|
W = make_positional_embedding(out, in)
|
||||||
SinusoidalPositionEmbedding(W)
|
SinusoidalPositionEmbedding(W)
|
||||||
end
|
end
|
||||||
|
|
||||||
function make_positional_embedding(dim_embedding::Int, seq_length::Int=1000; n::Int=10000)
|
function make_positional_embedding(dim_embedding::Integer, seq_length::Integer=1000; n::Integer=10000)
|
||||||
embedding = Matrix{Float32}(undef, dim_embedding, seq_length)
|
embedding = Matrix{Float32}(undef, dim_embedding, seq_length)
|
||||||
for pos in 1:seq_length
|
for pos in 1:seq_length
|
||||||
for row in 0:2:(dim_embedding-1)
|
for row in 0:2:(dim_embedding-1)
|
||||||
|
|
|
@ -8,7 +8,7 @@ using ProgressMeter
|
||||||
include("Embeddings.jl")
|
include("Embeddings.jl")
|
||||||
include("ConditionalChain.jl")
|
include("ConditionalChain.jl")
|
||||||
|
|
||||||
function make_spiral(rng::AbstractRNG, n_samples::Int=1000)
|
function make_spiral(rng::AbstractRNG, n_samples::Integer=1000)
|
||||||
t_min = 1.5π
|
t_min = 1.5π
|
||||||
t_max = 4.5π
|
t_max = 4.5π
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ function make_spiral(rng::AbstractRNG, n_samples::Int=1000)
|
||||||
permutedims([x y], (2, 1))
|
permutedims([x y], (2, 1))
|
||||||
end
|
end
|
||||||
|
|
||||||
make_spiral(n_samples::Int=1000) = make_spiral(Random.GLOBAL_RNG, n_samples)
|
make_spiral(n_samples::Integer=1000) = make_spiral(Random.GLOBAL_RNG, n_samples)
|
||||||
|
|
||||||
function normalize_zero_to_one(x)
|
function normalize_zero_to_one(x)
|
||||||
x_min, x_max = extrema(x)
|
x_min, x_max = extrema(x)
|
||||||
|
@ -42,7 +42,9 @@ scatter(data[1, :], data[2, :],
|
||||||
num_timesteps = 100
|
num_timesteps = 100
|
||||||
scheduler = Diffusers.DDPM(
|
scheduler = Diffusers.DDPM(
|
||||||
Vector{Float64},
|
Vector{Float64},
|
||||||
Diffusers.cosine_beta_schedule(num_timesteps, 0.999f0, 0.001f0),
|
Diffusers.rescale_zero_terminal_snr(
|
||||||
|
Diffusers.cosine_beta_schedule(num_timesteps, 0.999f0, 0.001f0)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
noise = randn(size(data))
|
noise = randn(size(data))
|
||||||
|
@ -104,7 +106,7 @@ model = ConditionalChain(
|
||||||
|
|
||||||
model(data, [100])
|
model(data, [100])
|
||||||
|
|
||||||
num_epochs = 1000
|
num_epochs = 100
|
||||||
loss = Flux.Losses.mse
|
loss = Flux.Losses.mse
|
||||||
opt = Flux.setup(Adam(0.0001), model)
|
opt = Flux.setup(Adam(0.0001), model)
|
||||||
dataloader = Flux.DataLoader(data |> cpu; batchsize=32, shuffle=true);
|
dataloader = Flux.DataLoader(data |> cpu; batchsize=32, shuffle=true);
|
||||||
|
|
|
@ -3,117 +3,122 @@ import NNlib: sigmoid
|
||||||
"""
|
"""
|
||||||
Linear beta schedule.
|
Linear beta schedule.
|
||||||
|
|
||||||
cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
|
|
||||||
|
|
||||||
## Input
|
## Input
|
||||||
* T (`Int`): number of timesteps
|
* `T::Integer`: number of timesteps
|
||||||
* β_1 (`Real := 0.0001f0`): initial value of β
|
* `β₁::Real=0.0001f0`: initial (t=1) value of β
|
||||||
* β_T (`Real := 0.02f0`): final value of β
|
* `β₋₁::Real=0.02f0`: final (t=T) value of β
|
||||||
|
|
||||||
## Output
|
## Output
|
||||||
* β (`Vector{Real}`): β_t values at each timestep t
|
* `β::Vector{Real}`: βₜ values at each timestep t
|
||||||
|
|
||||||
|
## References
|
||||||
|
* [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
|
||||||
"""
|
"""
|
||||||
function linear_beta_schedule(T::Int, β_1::Real=0.0001f0, β_T::Real=0.02f0)
|
function linear_beta_schedule(T::Integer, β₁::Real=0.0001f0, β₋₁::Real=0.02f0)
|
||||||
return range(start=β_1, stop=β_T, length=T)
|
return range(start=β₁, stop=β₋₁, length=T)
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Scaled linear beta schedule.
|
Scaled linear beta schedule.
|
||||||
|
|
||||||
cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
|
|
||||||
|
|
||||||
## Input
|
## Input
|
||||||
* T (`Int`): number of timesteps
|
* `T::Integer`: number of timesteps
|
||||||
* β_1 (`Real := 0.0001f0`): initial value of β
|
* `β₁::Real=0.0001f0`: initial value of β
|
||||||
* β_T (`Real := 0.02f0`): final value of β
|
* `β₋₁::Real=0.02f0`: final value of β
|
||||||
|
|
||||||
## Output
|
## Output
|
||||||
* β (`Vector{Real}`): β_t values at each timestep t
|
* `β::Vector{Real}`: βₜ values at each timestep t
|
||||||
|
|
||||||
|
## References
|
||||||
|
* [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
|
||||||
"""
|
"""
|
||||||
function scaled_linear_beta_schedule(T::Int, β_1::Real=0.0001f0, β_T::Real=0.02f0)
|
function scaled_linear_beta_schedule(T::Integer, β₁::Real=0.0001f0, β₋₁::Real=0.02f0)
|
||||||
return range(start=β_1^0.5, stop=β_T^0.5, length=T) .^ 2
|
return range(start=β₁^0.5, stop=β₋₁^0.5, length=T) .^ 2
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Sigmoid beta schedule.
|
Sigmoid beta schedule.
|
||||||
|
|
||||||
cf. [[2203.02923] GeoDiff: a Geometric Diffusion Model for Molecular Conformation Generation](https://arxiv.org/abs/2203.02923)
|
|
||||||
and [github.com:MinkaiXu/GeoDiff](https://github.com/MinkaiXu/GeoDiff/blob/ea0ca48045a2f7abfccd7f0df449e45eb6eae638/models/epsnet/diffusion.py#L57)
|
|
||||||
|
|
||||||
## Input
|
## Input
|
||||||
* T (`Int`): number of timesteps
|
* `T::Integer`: number of timesteps
|
||||||
* β_1 (`Real := 0.0001f0`): initial value of β
|
* `β₁::Real=0.0001f0`: initial value of β
|
||||||
* β_T (`Real := 0.02f0`): final value of β
|
* `β₋₁::Real=0.02f0`: final value of β
|
||||||
|
|
||||||
## Output
|
## Output
|
||||||
* β (`Vector{Real}`): β_t values at each timestep t
|
* `β::Vector{Real}`: βₜ values at each timestep t
|
||||||
|
|
||||||
|
## References
|
||||||
|
* [[2203.02923] GeoDiff: a Geometric Diffusion Model for Molecular Conformation Generation](https://arxiv.org/abs/2203.02923)
|
||||||
|
* [github.com:MinkaiXu/GeoDiff](https://github.com/MinkaiXu/GeoDiff/blob/ea0ca48045a2f7abfccd7f0df449e45eb6eae638/models/epsnet/diffusion.py#L57)
|
||||||
"""
|
"""
|
||||||
function sigmoid_beta_schedule(T::Int, β_1::Real=0.0001f0, β_T::Real=0.02f0)
|
function sigmoid_beta_schedule(T::Integer, β₁::Real=0.0001f0, β₋₁::Real=0.02f0)
|
||||||
x = range(start=-6, stop=6, length=T)
|
x = range(start=-6, stop=6, length=T)
|
||||||
return sigmoid(x) .* (β_T - β_1) .+ β_1
|
return sigmoid(x) .* (β₋₁ - β₁) .+ β₁
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Cosine beta schedule.
|
Cosine beta schedule.
|
||||||
|
|
||||||
cf. [[2102.09672] Improved Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2102.09672)
|
|
||||||
|
|
||||||
## Input
|
## Input
|
||||||
* T (`Int`): number of timesteps
|
* `T::Integer`: number of timesteps
|
||||||
* β_max (`Real := 0.999f0`): maximum value of β
|
* `βₘₐₓ::Real=0.999f0`: maximum value of β
|
||||||
* ϵ (`Real := 1e-3f0`): small value used to avoid division by zero
|
* `ϵ::Real=1e-3f0`: small value used to avoid division by zero
|
||||||
|
|
||||||
## Output
|
## Output
|
||||||
* β (`Vector{Real}`): β_t values at each timestep t
|
* `β::Vector{Real}`: βₜ values at each timestep t
|
||||||
|
|
||||||
|
## References
|
||||||
|
* [[2102.09672] Improved Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2102.09672)
|
||||||
|
* [github:openai/improved-diffusion](https://github.com/openai/improved-diffusion/blob/783b6740edb79fdb7d063250db2c51cc9545dcd1/improved_diffusion/gaussian_diffusion.py#L36)
|
||||||
"""
|
"""
|
||||||
function cosine_beta_schedule(T::Int, β_max::Real=0.999f0, ϵ::Real=0.001f0)
|
function cosine_beta_schedule(T::Integer, βₘₐₓ::Real=0.999f0, ϵ::Real=0.001f0)
|
||||||
α_bar(t) = cos((t + ϵ) / (1 + ϵ) * π / 2)^2
|
α̅(t) = cos((t / T + ϵ) / (1 + ϵ) * π / 2)^2
|
||||||
|
|
||||||
β = Float32[]
|
β = Vector{Real}(undef, T)
|
||||||
for t in 1:T
|
for t in 1:T
|
||||||
t1 = (t - 1) / T
|
αₜ = α̅(t) / α̅(t - 1)
|
||||||
t2 = t / T
|
|
||||||
|
|
||||||
β_t = 1 - α_bar(t2) / α_bar(t1)
|
βₜ = 1 - αₜ
|
||||||
β_t = min(β_max, β_t)
|
βₜ = min(βₘₐₓ, βₜ)
|
||||||
|
|
||||||
push!(β, β_t)
|
β[t] = βₜ
|
||||||
end
|
end
|
||||||
|
|
||||||
return β
|
return β
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Rescale betas to have zero terminal SNR.
|
Rescale betas to have zero terminal Signal to Noise Ratio (SNR).
|
||||||
|
|
||||||
cf. [[2305.08891] Rescaling Diffusion Models](https://arxiv.org/abs/2305.08891) (Algorithm 1)
|
|
||||||
|
|
||||||
## Input
|
## Input
|
||||||
* β (`AbstractArray`): β_t values at each timestep t
|
* `β::AbstractArray`: βₜ values at each timestep t
|
||||||
|
|
||||||
## Output
|
## Output
|
||||||
* β (`Vector{Real}`): rescaled β_t values at each timestep t
|
* `β::Vector{Real}`: rescaled βₜ values at each timestep t
|
||||||
|
|
||||||
|
## References
|
||||||
|
* [[2305.08891] Rescaling Diffusion Models](https://arxiv.org/abs/2305.08891) (Alg. 1)
|
||||||
"""
|
"""
|
||||||
function rescale_zero_terminal_snr(β::AbstractArray)
|
function rescale_zero_terminal_snr(β::AbstractArray)
|
||||||
# convert β to sqrt_α_cumprods
|
# convert β to ⎷α̅
|
||||||
α = 1 .- β
|
α = 1 .- β
|
||||||
α_cumprod = cumprod(α)
|
α̅ = cumprod(α)
|
||||||
sqrt_α_cumprods = sqrt.(α_cumprod)
|
⎷α̅ = sqrt.(α̅)
|
||||||
|
|
||||||
# store old extrema values
|
# store old extrema values
|
||||||
sqrt_α_cumprod_1 = sqrt_α_cumprods[1]
|
⎷α̅₁ = ⎷α̅[1]
|
||||||
sqrt_α_cumprod_T = sqrt_α_cumprods[end]
|
⎷α̅₋₁ = ⎷α̅[end]
|
||||||
|
|
||||||
# shift last timestep to zero
|
# shift last timestep to zero
|
||||||
sqrt_α_cumprods .-= sqrt_α_cumprod_T
|
⎷α̅ .-= ⎷α̅₋₁
|
||||||
|
|
||||||
# scale so that first timestep reaches old values
|
# scale so that first timestep reaches old values
|
||||||
sqrt_α_cumprods *= sqrt_α_cumprod_1 / (sqrt_α_cumprod_1 - sqrt_α_cumprod_T)
|
⎷α̅ *= ⎷α̅₁ / (⎷α̅₁ - ⎷α̅₋₁)
|
||||||
|
|
||||||
# convert back sqrt_α_cumprods to β
|
# convert back ⎷α̅ to β
|
||||||
α_cumprod = sqrt_α_cumprods .^ 2
|
α̅ = ⎷α̅ .^ 2
|
||||||
α = α_cumprod[2:end] ./ α_cumprod[1:end-1]
|
α = α̅[2:end] ./ α̅[1:end-1]
|
||||||
α = vcat(α_cumprod[1], α)
|
α = vcat(α̅[1], α)
|
||||||
β = 1 .- α
|
β = 1 .- α
|
||||||
|
|
||||||
return β
|
return β
|
||||||
|
|
134
src/DDPM.jl
134
src/DDPM.jl
|
@ -3,39 +3,65 @@ include("Schedulers.jl")
|
||||||
"""
|
"""
|
||||||
Denoising Diffusion Probabilistic Models (DDPM) scheduler.
|
Denoising Diffusion Probabilistic Models (DDPM) scheduler.
|
||||||
|
|
||||||
cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
|
## References
|
||||||
|
* [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
|
||||||
"""
|
"""
|
||||||
struct DDPM{V<:AbstractVector} <: Scheduler
|
struct DDPM{V<:AbstractVector} <: Scheduler
|
||||||
# number of diffusion steps used to train the model.
|
T::Integer # length of markov chain
|
||||||
T_train::Int
|
|
||||||
|
|
||||||
# the betas used for the diffusion steps
|
β::V # beta variance schedule
|
||||||
β::V
|
α::V # 1 - beta
|
||||||
|
|
||||||
# internal variables used for computation (derived from β)
|
⎷α::V # square root of α
|
||||||
α::V
|
⎷β::V # square root of β
|
||||||
α_cumprods::V
|
|
||||||
α_cumprod_prevs::V
|
α̅::V # cumulative product of α
|
||||||
sqrt_α_cumprods::V
|
β̅::V # 1 - α̅ (≠ cumprod(β))
|
||||||
sqrt_one_minus_α_cumprods::V
|
|
||||||
|
α̅₋₁::V # right-shifted α̅
|
||||||
|
β̅₋₁::V # 1 - α̅₋₁
|
||||||
|
|
||||||
|
⎷α̅::V # square root of α̅
|
||||||
|
⎷β̅::V # square root of β̅
|
||||||
|
|
||||||
|
⎷α̅₋₁::V # square root of α̅₋₁
|
||||||
|
⎷β̅₋₁::V # square root of β̅₋₁
|
||||||
end
|
end
|
||||||
|
|
||||||
function DDPM(V::DataType, β::AbstractVector)
|
function DDPM(V::DataType, β::AbstractVector)
|
||||||
α = 1 .- β
|
T = length(β)
|
||||||
α_cumprods = cumprod(α)
|
|
||||||
α_cumprod_prevs = [1, (α_cumprods[1:end-1])...]
|
|
||||||
|
|
||||||
sqrt_α_cumprods = sqrt.(α_cumprods)
|
α = 1 .- β
|
||||||
sqrt_one_minus_α_cumprods = sqrt.(1 .- α_cumprods)
|
|
||||||
|
⎷α = sqrt.(α)
|
||||||
|
⎷β = sqrt.(β)
|
||||||
|
|
||||||
|
α̅ = cumprod(α)
|
||||||
|
β̅ = 1 .- α̅
|
||||||
|
|
||||||
|
α̅₋₁ = [1, (α̅[1:end-1])...]
|
||||||
|
β̅₋₁ = 1 .- α̅₋₁
|
||||||
|
|
||||||
|
⎷α̅ = sqrt.(α̅)
|
||||||
|
⎷β̅ = sqrt.(β̅)
|
||||||
|
|
||||||
|
⎷α̅₋₁ = sqrt.(α̅₋₁)
|
||||||
|
⎷β̅₋₁ = sqrt.(β̅₋₁)
|
||||||
|
|
||||||
DDPM{V}(
|
DDPM{V}(
|
||||||
length(β),
|
T,
|
||||||
β,
|
β,
|
||||||
α,
|
α,
|
||||||
α_cumprods,
|
⎷α,
|
||||||
α_cumprod_prevs,
|
⎷β,
|
||||||
sqrt_α_cumprods,
|
α̅,
|
||||||
sqrt_one_minus_α_cumprods,
|
β̅,
|
||||||
|
α̅₋₁,
|
||||||
|
β̅₋₁,
|
||||||
|
⎷α̅,
|
||||||
|
⎷β̅,
|
||||||
|
⎷α̅₋₁,
|
||||||
|
⎷β̅₋₁,
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -43,47 +69,47 @@ end
|
||||||
Remove noise from model output using the backward diffusion process.
|
Remove noise from model output using the backward diffusion process.
|
||||||
|
|
||||||
## Input
|
## Input
|
||||||
* scheduler (`DDPM`): scheduler to use
|
* `scheduler::DDPM`: scheduler to use
|
||||||
* sample (`AbstractArray`): sample to remove noise from, i.e. model_input
|
* `xₜ::AbstractArray`: sample to be denoised
|
||||||
* model_output (`AbstractArray`): predicted noise from the model
|
* `ϵᵧ::AbstractArray`: predicted noise to remove
|
||||||
* timesteps (`AbstractArray`): timesteps to remove noise from
|
* `t::AbstractArray`: timestep t of `xₜ`
|
||||||
|
|
||||||
## Output
|
## Output
|
||||||
* pred_prev_sample (`AbstractArray`): denoised sample at t=t-1
|
* `xₜ₋₁::AbstractArray`: denoised sample at timestep t-1
|
||||||
* x_0_pred (`AbstractArray`): denoised sample at t=0
|
* `x̂₀::AbstractArray`: denoised sample at t=0
|
||||||
"""
|
"""
|
||||||
function step(
|
function step(
|
||||||
scheduler::DDPM,
|
scheduler::DDPM,
|
||||||
sample::AbstractArray,
|
xₜ::AbstractArray,
|
||||||
model_output::AbstractArray,
|
ϵᵧ::AbstractArray,
|
||||||
timesteps::AbstractArray,
|
t::AbstractArray,
|
||||||
)
|
)
|
||||||
# 1. compute alphas, betas
|
# retrieve scheduler variables at timesteps t
|
||||||
α_cumprod_t = scheduler.α_cumprods[timesteps]
|
βₜ = scheduler.β[t]
|
||||||
α_cumprod_t_prev = scheduler.α_cumprods[timesteps.-1]
|
β̅ₜ = scheduler.β̅[t]
|
||||||
β_cumprod_t = 1 .- α_cumprod_t
|
β̅ₜ₋₁ = scheduler.β̅₋₁[t]
|
||||||
β_cumprod_t_prev = 1 .- α_cumprod_t_prev
|
⎷αₜ = scheduler.⎷α[t]
|
||||||
current_α_t = α_cumprod_t ./ α_cumprod_t_prev
|
⎷α̅ₜ = scheduler.⎷α̅[t]
|
||||||
current_β_t = 1 .- current_α_t
|
⎷α̅ₜ₋₁ = scheduler.⎷α̅₋₁[t]
|
||||||
|
⎷β̅ₜ = scheduler.⎷β̅[t]
|
||||||
|
|
||||||
# 2. compute predicted original sample from predicted noise also called
|
# compute predicted original sample x̂₀
|
||||||
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
|
# arxiv:2006.11239 Eq. 15
|
||||||
# epsilon prediction type
|
# arxiv:2208.11970 Eq. 115
|
||||||
# print shapes of thingies
|
x̂₀ = (xₜ - ⎷β̅ₜ' .* ϵᵧ) ./ ⎷α̅ₜ'
|
||||||
x_0_pred = (sample - sqrt.(β_cumprod_t)' .* model_output) ./ sqrt.(α_cumprod_t)'
|
|
||||||
|
|
||||||
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
|
# compute mean μ̃ₜ of the predicted previous sample
|
||||||
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
# arxiv:2006.11239 Eq. 7
|
||||||
pred_original_sample_coeff = (sqrt.(α_cumprod_t_prev) .* current_β_t) ./ β_cumprod_t
|
# arxiv:2208.11970 Eq. 84
|
||||||
current_sample_coeff = sqrt.(current_α_t) .* β_cumprod_t_prev ./ β_cumprod_t
|
λ₀ = ⎷α̅ₜ₋₁ .* βₜ ./ β̅ₜ
|
||||||
|
λₜ = ⎷αₜ .* β̅ₜ₋₁ ./ β̅ₜ # TODO: this could be stored in the scheduler
|
||||||
|
μ̃ₜ = λ₀' .* x̂₀ + λₜ' .* xₜ
|
||||||
|
|
||||||
# 5. Compute predicted previous sample µ_t
|
# sample predicted previous sample xₜ₋₁
|
||||||
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
# arxiv:2006.11239 Eq. 6
|
||||||
pred_prev_sample = pred_original_sample_coeff' .* x_0_pred + current_sample_coeff' .* sample
|
# arxiv:2208.11970 Eq. 70
|
||||||
|
σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ # TODO: this could be stored in the scheduler
|
||||||
|
xₜ₋₁ = μ̃ₜ + σₜ' .* randn(size(ϵᵧ))
|
||||||
|
|
||||||
# 6. Add noise
|
return xₜ₋₁, x̂₀
|
||||||
variance = sqrt.(scheduler.β[timesteps])' .* randn(size(model_output))
|
|
||||||
pred_prev_sample = pred_prev_sample + variance
|
|
||||||
|
|
||||||
return pred_prev_sample, x_0_pred
|
|
||||||
end
|
end
|
||||||
|
|
|
@ -6,5 +6,6 @@ include("BetaSchedulers.jl")
|
||||||
|
|
||||||
# concrete types
|
# concrete types
|
||||||
include("DDPM.jl")
|
include("DDPM.jl")
|
||||||
|
# include("DDIM.jl")
|
||||||
|
|
||||||
end # module Diffusers
|
end # module Diffusers
|
||||||
|
|
|
@ -6,24 +6,24 @@ Add noise to clean data using the forward diffusion process.
|
||||||
cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) (Eq. 4)
|
cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) (Eq. 4)
|
||||||
|
|
||||||
## Input
|
## Input
|
||||||
* scheduler (`Scheduler`): scheduler to use
|
* `scheduler::Scheduler`: scheduler to use
|
||||||
* clean_data (`AbstractArray`): clean data to add noise to
|
* `x₀::AbstractArray`: clean data to add noise to
|
||||||
* noise (`AbstractArray`): noise to add to clean data
|
* `ϵ::AbstractArray`: noise to add to clean data
|
||||||
* timesteps (`AbstractArray`): timesteps used to weight the noise
|
* `t::AbstractArray`: timesteps used to weight the noise
|
||||||
|
|
||||||
## Output
|
## Output
|
||||||
* noisy_data (`AbstractArray`): noisy data at the given timesteps
|
* `xₜ::AbstractArray`: noisy data at the given timesteps
|
||||||
"""
|
"""
|
||||||
function add_noise(
|
function add_noise(
|
||||||
scheduler::Scheduler,
|
scheduler::Scheduler,
|
||||||
clean_data::AbstractArray,
|
x₀::AbstractArray,
|
||||||
noise::AbstractArray,
|
ϵ::AbstractArray,
|
||||||
timesteps::AbstractArray,
|
t::AbstractArray,
|
||||||
)
|
)
|
||||||
sqrt_α_cumprod_t = scheduler.sqrt_α_cumprods[timesteps]
|
⎷α̅ₜ = scheduler.⎷α̅[t]
|
||||||
sqrt_one_minus_α_cumprod_t = scheduler.sqrt_one_minus_α_cumprods[timesteps]
|
⎷β̅ₜ = scheduler.⎷β̅[t]
|
||||||
|
|
||||||
noisy_data = sqrt_α_cumprod_t' .* clean_data + sqrt_one_minus_α_cumprod_t' .* noise
|
xₜ = ⎷α̅ₜ' .* x₀ + ⎷β̅ₜ' .* ϵ
|
||||||
|
|
||||||
return noisy_data
|
return xₜ
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in a new issue