diff --git a/docs/src/beta_schedules.md b/docs/src/beta_schedules.md index c90dda9..fc0f5dd 100644 --- a/docs/src/beta_schedules.md +++ b/docs/src/beta_schedules.md @@ -9,11 +9,13 @@ T = 1000 β_scaled_linear = scaled_linear_beta_schedule(T) β_cosine = cosine_beta_schedule(T) β_sigmoid = sigmoid_beta_schedule(T) +β_exponential = exponential_beta_schedule(T) α̅_linear = cumprod(1 .- β_linear) α̅_scaled_linear = cumprod(1 .- β_scaled_linear) α̅_cosine = cumprod(1 .- β_cosine) α̅_sigmoid = cumprod(1 .- β_sigmoid) +α̅_exponential = cumprod(1 .- β_exponential) p1 = plot( [ @@ -21,6 +23,7 @@ p1 = plot( scatter(y=β_scaled_linear, name="Scaled linear", visible="legendonly"), scatter(y=β_cosine, name="Cosine"), scatter(y=β_sigmoid, name="Sigmoid", visible="legendonly"), + scatter(y=β_exponential, name="Exponential", visible="legendonly"), ], Layout( updatemenus=[ @@ -57,6 +60,7 @@ p2 = plot( scatter(y=α̅_scaled_linear, name="Scaled linear", visible="legendonly"), scatter(y=α̅_cosine, name="Cosine"), scatter(y=α̅_sigmoid, name="Sigmoid", visible="legendonly"), + scatter(y=α̅_exponential, name="Exponential", visible="legendonly"), ], Layout( updatemenus=[ @@ -96,7 +100,6 @@ nothing ``` - ```@autodocs Modules = [Diffusers.BetaSchedules] ``` diff --git a/src/BetaSchedules/BetaSchedules.jl b/src/BetaSchedules/BetaSchedules.jl index ce953a7..52b674d 100644 --- a/src/BetaSchedules/BetaSchedules.jl +++ b/src/BetaSchedules/BetaSchedules.jl @@ -5,6 +5,7 @@ include("Linear.jl") include("ScaledLinear.jl") include("Cosine.jl") include("Sigmoid.jl") +include("Exponential.jl") # utils include("ZeroSNR.jl") @@ -15,6 +16,7 @@ export scaled_linear_beta_schedule, cosine_beta_schedule, sigmoid_beta_schedule, + exponential_beta_schedule, # Beta Schedule utils rescale_zero_terminal_snr diff --git a/src/BetaSchedules/Exponential.jl b/src/BetaSchedules/Exponential.jl new file mode 100644 index 0000000..3d32f4e --- /dev/null +++ b/src/BetaSchedules/Exponential.jl @@ -0,0 +1,27 @@ +""" +Exponential beta schedule. + +## Input + * `T::Int`: number of timesteps + * `βₘₐₓ::Real=0.999f0`: maximum value of β + +## Output + * `β::Vector{Real}`: βₜ values at each timestep t + +## References +""" +function exponential_beta_schedule(T::Integer, βₘₐₓ::Real=0.999f0) + α̅(t) = exp(-12 * t / T) + + β = Vector{Real}(undef, T) + for t in 1:T + αₜ = α̅(t) / α̅(t - 1) + + βₜ = 1 - αₜ + βₜ = min(βₘₐₓ, βₜ) + + β[t] = βₜ + end + + return β +end diff --git a/src/Diffusers.jl b/src/Diffusers.jl index f07928c..3e9f632 100644 --- a/src/Diffusers.jl +++ b/src/Diffusers.jl @@ -8,6 +8,7 @@ import .BetaSchedules: scaled_linear_beta_schedule, cosine_beta_schedule, sigmoid_beta_schedule, + exponential_beta_schedule, # Beta Schedule utils rescale_zero_terminal_snr @@ -18,6 +19,7 @@ export scaled_linear_beta_schedule, cosine_beta_schedule, sigmoid_beta_schedule, + exponential_beta_schedule, # Beta Schedule utils rescale_zero_terminal_snr diff --git a/src/Schedulers/DDPM.jl b/src/Schedulers/DDPM.jl index 9aa05e5..2b3f870 100644 --- a/src/Schedulers/DDPM.jl +++ b/src/Schedulers/DDPM.jl @@ -32,13 +32,13 @@ struct DDPM{V<:AbstractVector} <: Scheduler β̅::V # 1 - α̅ (≠ cumprod(β)) α̅₋₁::V # right-shifted α̅ - β̅₋₁::V # 1 - α̅₋₁ + β̅₋₁::V # right-shifted β̅ ⎷α̅::V # square root of α̅ ⎷β̅::V # square root of β̅ - ⎷α̅₋₁::V # square root of α̅₋₁ - ⎷β̅₋₁::V # square root of β̅₋₁ + ⎷α̅₋₁::V # right-shifted ⎷α̅ + ⎷β̅₋₁::V # right-shifted ⎷β̅ end function DDPM(β::AbstractVector)