mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-11-09 15:02:02 +00:00
✨ (BetaSchedulers) add rescale_zero_terminal_snr
This commit is contained in:
parent
972013637d
commit
96be7c3307
|
@ -11,7 +11,7 @@ cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/ab
|
|||
* β_T (`Real := 0.02f0`): final value of β
|
||||
|
||||
## Output
|
||||
* βs (`Vector{Real}`): β_t values at each timestep t
|
||||
* β (`Vector{Real}`): β_t values at each timestep t
|
||||
"""
|
||||
function linear_beta_schedule(T::Int, β_1::Real=0.0001f0, β_T::Real=0.02f0)
|
||||
return range(start=β_1, stop=β_T, length=T)
|
||||
|
@ -28,7 +28,7 @@ cf. [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/ab
|
|||
* β_T (`Real := 0.02f0`): final value of β
|
||||
|
||||
## Output
|
||||
* βs (`Vector{Real}`): β_t values at each timestep t
|
||||
* β (`Vector{Real}`): β_t values at each timestep t
|
||||
"""
|
||||
function scaled_linear_beta_schedule(T::Int, β_1::Real=0.0001f0, β_T::Real=0.02f0)
|
||||
return range(start=β_1^0.5, stop=β_T^0.5, length=T) .^ 2
|
||||
|
@ -46,11 +46,11 @@ and [github.com:MinkaiXu/GeoDiff](https://github.com/MinkaiXu/GeoDiff/blob/ea0ca
|
|||
* β_T (`Real := 0.02f0`): final value of β
|
||||
|
||||
## Output
|
||||
* βs (`Vector{Real}`): β_t values at each timestep t
|
||||
* β (`Vector{Real}`): β_t values at each timestep t
|
||||
"""
|
||||
function sigmoid_beta_schedule(T::Int, β_1::Real=0.0001f0, β_T::Real=0.02f0)
|
||||
x = range(start=-6, stop=6, length=T)
|
||||
return sigmoid(x) * (β_T - β_1) + β_1
|
||||
return sigmoid(x) .* (β_T - β_1) .+ β_1
|
||||
end
|
||||
|
||||
"""
|
||||
|
@ -64,12 +64,12 @@ cf. [[2102.09672] Improved Denoising Diffusion Probabilistic Models](https://arx
|
|||
* ϵ (`Real := 1e-3f0`): small value used to avoid division by zero
|
||||
|
||||
## Output
|
||||
* βs (`Vector{Real}`): β_t values at each timestep t
|
||||
* β (`Vector{Real}`): β_t values at each timestep t
|
||||
"""
|
||||
function cosine_beta_schedule(T::Int, β_max::Real=0.999f0, ϵ::Real=1e-3f0)
|
||||
function cosine_beta_schedule(T::Int, β_max::Real=0.999f0, ϵ::Real=0.001f0)
|
||||
α_bar(t) = cos((t + ϵ) / (1 + ϵ) * π / 2)^2
|
||||
|
||||
βs = Float32[]
|
||||
β = Float32[]
|
||||
for t in 1:T
|
||||
t1 = (t - 1) / T
|
||||
t2 = t / T
|
||||
|
@ -77,8 +77,44 @@ function cosine_beta_schedule(T::Int, β_max::Real=0.999f0, ϵ::Real=1e-3f0)
|
|||
β_t = 1 - α_bar(t2) / α_bar(t1)
|
||||
β_t = min(β_max, β_t)
|
||||
|
||||
push!(βs, β_t)
|
||||
push!(β, β_t)
|
||||
end
|
||||
|
||||
return βs
|
||||
return β
|
||||
end
|
||||
|
||||
"""
|
||||
Rescale betas to have zero terminal SNR.
|
||||
|
||||
cf. [[2305.08891] Rescaling Diffusion Models](https://arxiv.org/abs/2305.08891) (Algorithm 1)
|
||||
|
||||
## Input
|
||||
* β (`AbstractArray`): β_t values at each timestep t
|
||||
|
||||
## Output
|
||||
* β (`Vector{Real}`): rescaled β_t values at each timestep t
|
||||
"""
|
||||
function rescale_zero_terminal_snr(β::AbstractArray)
|
||||
# convert β to sqrt_α_cumprods
|
||||
α = 1 .- β
|
||||
α_cumprod = cumprod(α)
|
||||
sqrt_α_cumprods = sqrt.(α_cumprod)
|
||||
|
||||
# store old extrema values
|
||||
sqrt_α_cumprod_1 = sqrt_α_cumprods[1]
|
||||
sqrt_α_cumprod_T = sqrt_α_cumprods[end]
|
||||
|
||||
# shift last timestep to zero
|
||||
sqrt_α_cumprods .-= sqrt_α_cumprod_T
|
||||
|
||||
# scale so that first timestep reaches old values
|
||||
sqrt_α_cumprods *= sqrt_α_cumprod_1 / (sqrt_α_cumprod_1 - sqrt_α_cumprod_T)
|
||||
|
||||
# convert back sqrt_α_cumprods to β
|
||||
α_cumprod = sqrt_α_cumprods .^ 2
|
||||
α = α_cumprod[2:end] ./ α_cumprod[1:end-1]
|
||||
α = vcat(α_cumprod[1], α)
|
||||
β = 1 .- α
|
||||
|
||||
return β
|
||||
end
|
||||
|
|
Loading…
Reference in a new issue