mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-11-09 15:02:02 +00:00
37 lines
863 B
Julia
37 lines
863 B
Julia
|
"""
    rescale_zero_terminal_snr(β::AbstractArray)

Rescale betas to have zero terminal Signal to Noise Ratio (SNR).

The square roots √ᾱₜ of the cumulative products ᾱₜ = (1 - β₁)⋯(1 - βₜ) are shifted so
that the last one becomes exactly zero, then scaled so that the first one keeps its
original value. The result is converted back to per-timestep βₜ values.

## Input

* `β::AbstractArray`: βₜ values at each timestep t, given as a vector with at least two
  entries

## Output

* `β::Vector{<:AbstractFloat}`: rescaled βₜ values at each timestep t. The floating-point
  element type is derived from the input (e.g. `Float32` in, `Float32` out). The last
  value becomes `1` (zero terminal SNR). Trailing zero βₜ values make the corresponding
  ratio ᾱₜ / ᾱₜ₋₁ equal to 0 / 0, which yields `NaN`.

## Throws

* `ArgumentError`: if `β` has fewer than two entries, or if its first and last √ᾱₜ are
  equal (the rescaling would divide by zero and silently produce `NaN`s)

## References

* [[2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) (Alg. 1)
"""
function rescale_zero_terminal_snr(β::AbstractArray)
    length(β) >= 2 || throw(ArgumentError(
        "β must contain at least two timesteps, got $(length(β))"))

    # convert β to √ᾱ
    sqrt_α̅ = sqrt.(cumprod(1 .- β))

    # store old extrema values
    sqrt_α̅_first = first(sqrt_α̅)
    sqrt_α̅_last = last(sqrt_α̅)

    # equal extrema would turn the scale factor below into x / 0 and fill the result
    # with NaN, so reject such schedules explicitly
    sqrt_α̅_first != sqrt_α̅_last || throw(ArgumentError(
        "first and last √ᾱ are equal ($(sqrt_α̅_first)), " *
        "cannot rescale to zero terminal SNR"))

    # shift last timestep to zero, then scale so that first timestep keeps its old value
    scale = sqrt_α̅_first / (sqrt_α̅_first - sqrt_α̅_last)
    sqrt_α̅ .= (sqrt_α̅ .- sqrt_α̅_last) .* scale

    # convert back √ᾱ to β: α₁ = ᾱ₁ and αₜ = ᾱₜ / ᾱₜ₋₁ for t > 1
    α̅ = sqrt_α̅ .^ 2
    α = @views vcat(α̅[begin], α̅[begin+1:end] ./ α̅[begin:end-1])

    return 1 .- α
end
|