Diffusers.jl/src/BetaSchedules/ZeroSNR.jl

37 lines
863 B
Julia
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Rescale betas to have zero terminal Signal to Noise Ratio (SNR).
## Input
* `β::AbstractArray`: βₜ values at each timestep t
## Output
* `β::Vector{Real}`: rescaled βₜ values at each timestep t
## References
* [[2305.08891] Rescaling Diffusion Models](https://arxiv.org/abs/2305.08891) (Alg. 1)
"""
function rescale_zero_terminal_snr(β::AbstractArray)
# convert β to ⎷α̅
α = 1 .- β
α̅ = cumprod(α)
⎷α̅ = sqrt.(α̅)
# store old extrema values
⎷α̅₁ = ⎷α̅[1]
⎷α̅₋₁ = ⎷α̅[end]
# shift last timestep to zero
⎷α̅ .-= ⎷α̅₋₁
# scale so that first timestep reaches old values
⎷α̅ *= ⎷α̅₁ / (⎷α̅₁ - ⎷α̅₋₁)
# convert back ⎷α̅ to β
α̅ = ⎷α̅ .^ 2
α = α̅[2:end] ./ α̅[1:end-1]
α = vcat(α̅[1], α)
β = 1 .- α
return β
end