mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-09-18 18:45:28 +00:00
✨ (BetaSchedules) add Exponential schedule, from god knows where wth
This commit is contained in:
parent
2a0b019deb
commit
845ce78fa0
|
@ -9,11 +9,13 @@ T = 1000
|
||||||
β_scaled_linear = scaled_linear_beta_schedule(T)
|
β_scaled_linear = scaled_linear_beta_schedule(T)
|
||||||
β_cosine = cosine_beta_schedule(T)
|
β_cosine = cosine_beta_schedule(T)
|
||||||
β_sigmoid = sigmoid_beta_schedule(T)
|
β_sigmoid = sigmoid_beta_schedule(T)
|
||||||
|
β_exponential = exponential_beta_schedule(T)
|
||||||
|
|
||||||
α̅_linear = cumprod(1 .- β_linear)
|
α̅_linear = cumprod(1 .- β_linear)
|
||||||
α̅_scaled_linear = cumprod(1 .- β_scaled_linear)
|
α̅_scaled_linear = cumprod(1 .- β_scaled_linear)
|
||||||
α̅_cosine = cumprod(1 .- β_cosine)
|
α̅_cosine = cumprod(1 .- β_cosine)
|
||||||
α̅_sigmoid = cumprod(1 .- β_sigmoid)
|
α̅_sigmoid = cumprod(1 .- β_sigmoid)
|
||||||
|
α̅_exponential = cumprod(1 .- β_exponential)
|
||||||
|
|
||||||
p1 = plot(
|
p1 = plot(
|
||||||
[
|
[
|
||||||
|
@ -21,6 +23,7 @@ p1 = plot(
|
||||||
scatter(y=β_scaled_linear, name="Scaled linear", visible="legendonly"),
|
scatter(y=β_scaled_linear, name="Scaled linear", visible="legendonly"),
|
||||||
scatter(y=β_cosine, name="Cosine"),
|
scatter(y=β_cosine, name="Cosine"),
|
||||||
scatter(y=β_sigmoid, name="Sigmoid", visible="legendonly"),
|
scatter(y=β_sigmoid, name="Sigmoid", visible="legendonly"),
|
||||||
|
scatter(y=β_exponential, name="Exponential", visible="legendonly"),
|
||||||
],
|
],
|
||||||
Layout(
|
Layout(
|
||||||
updatemenus=[
|
updatemenus=[
|
||||||
|
@ -57,6 +60,7 @@ p2 = plot(
|
||||||
scatter(y=α̅_scaled_linear, name="Scaled linear", visible="legendonly"),
|
scatter(y=α̅_scaled_linear, name="Scaled linear", visible="legendonly"),
|
||||||
scatter(y=α̅_cosine, name="Cosine"),
|
scatter(y=α̅_cosine, name="Cosine"),
|
||||||
scatter(y=α̅_sigmoid, name="Sigmoid", visible="legendonly"),
|
scatter(y=α̅_sigmoid, name="Sigmoid", visible="legendonly"),
|
||||||
|
scatter(y=α̅_exponential, name="Exponential", visible="legendonly"),
|
||||||
],
|
],
|
||||||
Layout(
|
Layout(
|
||||||
updatemenus=[
|
updatemenus=[
|
||||||
|
@ -96,7 +100,6 @@ nothing
|
||||||
<object type="text/html" data="alpha_bar_schedules.html" style="width:100%;height:420px;"></object>
|
<object type="text/html" data="alpha_bar_schedules.html" style="width:100%;height:420px;"></object>
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
```@autodocs
|
```@autodocs
|
||||||
Modules = [Diffusers.BetaSchedules]
|
Modules = [Diffusers.BetaSchedules]
|
||||||
```
|
```
|
||||||
|
|
|
@ -5,6 +5,7 @@ include("Linear.jl")
|
||||||
include("ScaledLinear.jl")
|
include("ScaledLinear.jl")
|
||||||
include("Cosine.jl")
|
include("Cosine.jl")
|
||||||
include("Sigmoid.jl")
|
include("Sigmoid.jl")
|
||||||
|
include("Exponential.jl")
|
||||||
|
|
||||||
# utils
|
# utils
|
||||||
include("ZeroSNR.jl")
|
include("ZeroSNR.jl")
|
||||||
|
@ -15,6 +16,7 @@ export
|
||||||
scaled_linear_beta_schedule,
|
scaled_linear_beta_schedule,
|
||||||
cosine_beta_schedule,
|
cosine_beta_schedule,
|
||||||
sigmoid_beta_schedule,
|
sigmoid_beta_schedule,
|
||||||
|
exponential_beta_schedule,
|
||||||
|
|
||||||
# Beta Schedule utils
|
# Beta Schedule utils
|
||||||
rescale_zero_terminal_snr
|
rescale_zero_terminal_snr
|
||||||
|
|
27
src/BetaSchedules/Exponential.jl
Normal file
27
src/BetaSchedules/Exponential.jl
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
"""
|
||||||
|
Exponential beta schedule.
|
||||||
|
|
||||||
|
## Input
|
||||||
|
* `T::Int`: number of timesteps
|
||||||
|
* `βₘₐₓ::Real=0.999f0`: maximum value of β
|
||||||
|
|
||||||
|
## Output
|
||||||
|
* `β::Vector{Real}`: βₜ values at each timestep t
|
||||||
|
|
||||||
|
## References
|
||||||
|
"""
|
||||||
|
function exponential_beta_schedule(T::Integer, βₘₐₓ::Real=0.999f0)
|
||||||
|
α̅(t) = exp(-12 * t / T)
|
||||||
|
|
||||||
|
β = Vector{Real}(undef, T)
|
||||||
|
for t in 1:T
|
||||||
|
αₜ = α̅(t) / α̅(t - 1)
|
||||||
|
|
||||||
|
βₜ = 1 - αₜ
|
||||||
|
βₜ = min(βₘₐₓ, βₜ)
|
||||||
|
|
||||||
|
β[t] = βₜ
|
||||||
|
end
|
||||||
|
|
||||||
|
return β
|
||||||
|
end
|
|
@ -8,6 +8,7 @@ import .BetaSchedules:
|
||||||
scaled_linear_beta_schedule,
|
scaled_linear_beta_schedule,
|
||||||
cosine_beta_schedule,
|
cosine_beta_schedule,
|
||||||
sigmoid_beta_schedule,
|
sigmoid_beta_schedule,
|
||||||
|
exponential_beta_schedule,
|
||||||
|
|
||||||
# Beta Schedule utils
|
# Beta Schedule utils
|
||||||
rescale_zero_terminal_snr
|
rescale_zero_terminal_snr
|
||||||
|
@ -18,6 +19,7 @@ export
|
||||||
scaled_linear_beta_schedule,
|
scaled_linear_beta_schedule,
|
||||||
cosine_beta_schedule,
|
cosine_beta_schedule,
|
||||||
sigmoid_beta_schedule,
|
sigmoid_beta_schedule,
|
||||||
|
exponential_beta_schedule,
|
||||||
|
|
||||||
# Beta Schedule utils
|
# Beta Schedule utils
|
||||||
rescale_zero_terminal_snr
|
rescale_zero_terminal_snr
|
||||||
|
|
|
@ -32,13 +32,13 @@ struct DDPM{V<:AbstractVector} <: Scheduler
|
||||||
β̅::V # 1 - α̅ (≠ cumprod(β))
|
β̅::V # 1 - α̅ (≠ cumprod(β))
|
||||||
|
|
||||||
α̅₋₁::V # right-shifted α̅
|
α̅₋₁::V # right-shifted α̅
|
||||||
β̅₋₁::V # 1 - α̅₋₁
|
β̅₋₁::V # right-shifted β̅
|
||||||
|
|
||||||
⎷α̅::V # square root of α̅
|
⎷α̅::V # square root of α̅
|
||||||
⎷β̅::V # square root of β̅
|
⎷β̅::V # square root of β̅
|
||||||
|
|
||||||
⎷α̅₋₁::V # square root of α̅₋₁
|
⎷α̅₋₁::V # right-shifted ⎷α̅
|
||||||
⎷β̅₋₁::V # square root of β̅₋₁
|
⎷β̅₋₁::V # right-shifted ⎷β̅
|
||||||
end
|
end
|
||||||
|
|
||||||
function DDPM(β::AbstractVector)
|
function DDPM(β::AbstractVector)
|
||||||
|
|
Loading…
Reference in a new issue