mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-10-22 17:46:20 +00:00
Compare commits
5 commits
0780fef335
...
1a317e4703
Author | SHA1 | Date | |
---|---|---|---|
Laureηt | 1a317e4703 | ||
Laureηt | 1eec2cbf3e | ||
Laureηt | 90036c12b7 | ||
Laureηt | ec302abc83 | ||
Laureηt | 06c85732ee |
2
.github/workflows/documentation.yml
vendored
2
.github/workflows/documentation.yml
vendored
|
@ -19,7 +19,7 @@ jobs:
|
||||||
with:
|
with:
|
||||||
version: '1.9'
|
version: '1.9'
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
|
run: julia --project=docs/ -e 'using Pkg; Pkg.instantiate()'
|
||||||
- name: Build and deploy
|
- name: Build and deploy
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1,6 +1,7 @@
|
||||||
.direnv
|
.direnv
|
||||||
*.gif
|
*.gif
|
||||||
*.jld2
|
*.jld2
|
||||||
|
generated/
|
||||||
|
|
||||||
# https://github.com/github/gitignore/blob/main/Julia.gitignore
|
# https://github.com/github/gitignore/blob/main/Julia.gitignore
|
||||||
# Files generated by invoking Julia with --code-coverage
|
# Files generated by invoking Julia with --code-coverage
|
||||||
|
|
1
.vscode/settings.json
vendored
1
.vscode/settings.json
vendored
|
@ -1,5 +1,4 @@
|
||||||
{
|
{
|
||||||
"julia.useRevise": true,
|
"julia.useRevise": true,
|
||||||
"julia.persistentSession.enabled": true,
|
|
||||||
"editor.unicodeHighlight.ambiguousCharacters": false,
|
"editor.unicodeHighlight.ambiguousCharacters": false,
|
||||||
}
|
}
|
2822
docs/Manifest.toml
Normal file
2822
docs/Manifest.toml
Normal file
File diff suppressed because it is too large
Load diff
|
@ -1,6 +1,12 @@
|
||||||
[deps]
|
[deps]
|
||||||
|
DenoisingDiffusion = "32e9e46b-ad0f-4c80-b32d-4f6f824844ef"
|
||||||
Diffusers = "90edb7a8-79d7-49b2-b6b1-9322c3fdead8"
|
Diffusers = "90edb7a8-79d7-49b2-b6b1-9322c3fdead8"
|
||||||
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
|
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
|
||||||
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
|
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
|
||||||
|
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
|
||||||
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
|
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
|
||||||
|
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
|
||||||
|
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
|
||||||
PlotlyJS = "f0f68f2c-4968-5e81-91da-67840de0976a"
|
PlotlyJS = "f0f68f2c-4968-5e81-91da-67840de0976a"
|
||||||
|
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
|
||||||
|
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
|
||||||
|
|
12
docs/make.jl
12
docs/make.jl
|
@ -1,6 +1,9 @@
|
||||||
using Diffusers
|
using Diffusers
|
||||||
using Documenter
|
using Documenter
|
||||||
using DocumenterCitations
|
using DocumenterCitations
|
||||||
|
using Literate
|
||||||
|
|
||||||
|
Literate.markdown(joinpath(@__DIR__, "..", "examples", "beta_schedulers_comparison.jl"), joinpath(@__DIR__, "src", "generated"))
|
||||||
|
|
||||||
DocMeta.setdocmeta!(Diffusers, :DocTestSetup, :(using Diffusers); recursive=true)
|
DocMeta.setdocmeta!(Diffusers, :DocTestSetup, :(using Diffusers); recursive=true)
|
||||||
bib = CitationBibliography(joinpath(@__DIR__, "src", "refs.bib"))
|
bib = CitationBibliography(joinpath(@__DIR__, "src", "refs.bib"))
|
||||||
|
@ -18,8 +21,13 @@ makedocs(bib;
|
||||||
linkcheck=true,
|
linkcheck=true,
|
||||||
pages=[
|
pages=[
|
||||||
"Home" => "index.md",
|
"Home" => "index.md",
|
||||||
"Schedulers" => "schedulers.md",
|
"API" => [
|
||||||
"Beta Schedules" => "beta_schedules.md",
|
"Schedulers" => "schedulers.md",
|
||||||
|
"Beta Schedules" => "beta_schedules.md",
|
||||||
|
],
|
||||||
|
"Examples" => [
|
||||||
|
"Beta Schedules Comparison" => "generated/beta_schedulers_comparison.md",
|
||||||
|
],
|
||||||
"References" => "references.md",
|
"References" => "references.md",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -33,3 +33,30 @@
|
||||||
archiveprefix = {arXiv},
|
archiveprefix = {arXiv},
|
||||||
primaryclass = {cs.LG}
|
primaryclass = {cs.LG}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@misc{lin2023common,
|
||||||
|
title = {Common Diffusion Noise Schedules and Sample Steps are Flawed},
|
||||||
|
author = {Shanchuan Lin and Bingchen Liu and Jiashi Li and Xiao Yang},
|
||||||
|
year = {2023},
|
||||||
|
eprint = {2305.08891},
|
||||||
|
archiveprefix = {arXiv},
|
||||||
|
primaryclass = {cs.CV}
|
||||||
|
}
|
||||||
|
|
||||||
|
@misc{xu2022geodiff,
|
||||||
|
title = {GeoDiff: a Geometric Diffusion Model for Molecular Conformation Generation},
|
||||||
|
author = {Minkai Xu and Lantao Yu and Yang Song and Chence Shi and Stefano Ermon and Jian Tang},
|
||||||
|
year = {2022},
|
||||||
|
eprint = {2203.02923},
|
||||||
|
archiveprefix = {arXiv},
|
||||||
|
primaryclass = {cs.LG}
|
||||||
|
}
|
||||||
|
|
||||||
|
@misc{salimans2022progressive,
|
||||||
|
title = {Progressive Distillation for Fast Sampling of Diffusion Models},
|
||||||
|
author = {Tim Salimans and Jonathan Ho},
|
||||||
|
year = {2022},
|
||||||
|
eprint = {2202.00512},
|
||||||
|
archiveprefix = {arXiv},
|
||||||
|
primaryclass = {cs.LG}
|
||||||
|
}
|
180
examples/beta_schedulers_comparison.jl
Normal file
180
examples/beta_schedulers_comparison.jl
Normal file
|
@ -0,0 +1,180 @@
|
||||||
|
# This example compares the different beta schedules available in Diffusers.jl.
|
||||||
|
# Code related to the generation of the datasets and the plots is hidden.
|
||||||
|
|
||||||
|
using Diffusers.Schedulers: DDPM, forward # hide
|
||||||
|
using Diffusers.BetaSchedules # hide
|
||||||
|
using ProgressMeter # hide
|
||||||
|
using LaTeXStrings # hide
|
||||||
|
using Random # hide
|
||||||
|
using Plots # hide
|
||||||
|
using Flux # hide
|
||||||
|
using MLDatasets # hide
|
||||||
|
|
||||||
|
function normalize_zero_to_one(x) # hide
|
||||||
|
x_min, x_max = extrema(x) # hide
|
||||||
|
x_norm = (x .- x_min) ./ (x_max - x_min) # hide
|
||||||
|
x_norm # hide
|
||||||
|
end # hide
|
||||||
|
|
||||||
|
function normalize_neg_one_to_one(x) # hide
|
||||||
|
2 * normalize_zero_to_one(x) .- 1 # hide
|
||||||
|
end # hide
|
||||||
|
|
||||||
|
num_timesteps = 100 # hide
|
||||||
|
beta_schedules = [ # hide
|
||||||
|
linear_beta_schedule, # hide
|
||||||
|
scaled_linear_beta_schedule, # hide
|
||||||
|
cosine_beta_schedule, # hide
|
||||||
|
sigmoid_beta_schedule, # hide
|
||||||
|
exponential_beta_schedule, # hide
|
||||||
|
] # hide
|
||||||
|
schedulers = [ # hide
|
||||||
|
DDPM(collect(schedule(num_timesteps))) for schedule in beta_schedules # hide
|
||||||
|
]; # hide
|
||||||
|
|
||||||
|
# ## Swiss Roll
|
||||||
|
|
||||||
|
function make_spiral(n_samples::Integer=1000, t_min::Real=1.5π, t_max::Real=4.5π) # hide
|
||||||
|
t = rand(typeof(t_min), n_samples) * (t_max - t_min) .+ t_min # hide
|
||||||
|
|
||||||
|
x = t .* cos.(t) # hide
|
||||||
|
y = t .* sin.(t) # hide
|
||||||
|
|
||||||
|
permutedims([x y], (2, 1)) # hide
|
||||||
|
end # hide
|
||||||
|
|
||||||
|
n_points = 1000; # hide
|
||||||
|
dataset = make_spiral(n_points, 1.5f0 * π, 4.5f0 * π); # hide
|
||||||
|
dataset = normalize_neg_one_to_one(dataset); # hide
|
||||||
|
|
||||||
|
noise = randn(Float32, size(dataset)) # hide
|
||||||
|
anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20), dims=1) # hide
|
||||||
|
plots = [] # hide
|
||||||
|
for (i, (scheduler, beta_schedule)) in enumerate(zip(schedulers, beta_schedules)) # hide
|
||||||
|
if t == 0 # hide
|
||||||
|
scatter(dataset[1, :], dataset[2, :], # hide
|
||||||
|
alpha=0.5, # hide
|
||||||
|
aspectratio=:equal, # hide
|
||||||
|
legend=false, # hide
|
||||||
|
) # hide
|
||||||
|
plot = scatter!(dataset[1, :], dataset[2, :], # hide
|
||||||
|
alpha=0.5, # hide
|
||||||
|
aspectratio=:equal, # hide
|
||||||
|
) # hide
|
||||||
|
title!(string(beta_schedule)) # hide
|
||||||
|
xlims!(-3, 3) # hide
|
||||||
|
ylims!(-3, 3) # hide
|
||||||
|
else # hide
|
||||||
|
scatter(dataset[1, :], dataset[2, :], # hide
|
||||||
|
alpha=0.5, # hide
|
||||||
|
aspectratio=:equal, # hide
|
||||||
|
legend=false, # hide
|
||||||
|
) # hide
|
||||||
|
noisy_data = forward(scheduler, dataset, noise, [t]) # hide
|
||||||
|
plot = scatter!(noisy_data[1, :], noisy_data[2, :], # hide
|
||||||
|
alpha=0.5, # hide
|
||||||
|
aspectratio=:equal, # hide
|
||||||
|
) # hide
|
||||||
|
title!(string(beta_schedule)) # hide
|
||||||
|
xlims!(-3, 3) # hide
|
||||||
|
ylims!(-3, 3) # hide
|
||||||
|
end # hide
|
||||||
|
push!(plots, plot) # hide
|
||||||
|
end # hide
|
||||||
|
plot(plots...; size=(1200, 800)) # hide
|
||||||
|
end # hide
|
||||||
|
|
||||||
|
gif(anim, anim.dir * ".gif", fps=20) # hide
|
||||||
|
|
||||||
|
# ## Double Square
|
||||||
|
|
||||||
|
function make_square(n_samples::Integer=1000) # hide
|
||||||
|
x = rand(n_samples) .* 2 .- 1 # hide
|
||||||
|
y = rand(n_samples) .* 2 .- 1 # hide
|
||||||
|
p = permutedims([x y], (2, 1)) # hide
|
||||||
|
p ./ maximum(abs.(p), dims=1) # hide
|
||||||
|
end # hide
|
||||||
|
|
||||||
|
dataset = hcat( # hide
|
||||||
|
make_square(Int(n_points / 2)) ./ 2 .- 1.5, # hide
|
||||||
|
make_square(Int(n_points / 2)) ./ 2 .+ 1.5 # hide
|
||||||
|
) # hide
|
||||||
|
|
||||||
|
noise = randn(Float32, size(dataset)) # hide
|
||||||
|
anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20), dims=1) # hide
|
||||||
|
plots = [] # hide
|
||||||
|
for (i, (scheduler, beta_schedule)) in enumerate(zip(schedulers, beta_schedules)) # hide
|
||||||
|
if t == 0 # hide
|
||||||
|
scatter(dataset[1, :], dataset[2, :], # hide
|
||||||
|
alpha=0.5, # hide
|
||||||
|
aspectratio=:equal, # hide
|
||||||
|
legend=false, # hide
|
||||||
|
) # hide
|
||||||
|
plot = scatter!(dataset[1, :], dataset[2, :], # hide
|
||||||
|
alpha=0.5, # hide
|
||||||
|
aspectratio=:equal, # hide
|
||||||
|
) # hide
|
||||||
|
title!(string(beta_schedule)) # hide
|
||||||
|
xlims!(-3, 3) # hide
|
||||||
|
ylims!(-3, 3) # hide
|
||||||
|
else # hide
|
||||||
|
scatter(dataset[1, :], dataset[2, :], # hide
|
||||||
|
alpha=0.5, # hide
|
||||||
|
aspectratio=:equal, # hide
|
||||||
|
legend=false, # hide
|
||||||
|
) # hide
|
||||||
|
noisy_data = forward(scheduler, dataset, noise, [t]) # hide
|
||||||
|
plot = scatter!(noisy_data[1, :], noisy_data[2, :], # hide
|
||||||
|
alpha=0.5, # hide
|
||||||
|
aspectratio=:equal, # hide
|
||||||
|
) # hide
|
||||||
|
title!(string(beta_schedule)) # hide
|
||||||
|
xlims!(-3, 3) # hide
|
||||||
|
ylims!(-3, 3) # hide
|
||||||
|
end # hide
|
||||||
|
push!(plots, plot) # hide
|
||||||
|
end # hide
|
||||||
|
plot(plots...; size=(1200, 800)) # hide
|
||||||
|
end # hide
|
||||||
|
|
||||||
|
gif(anim, anim.dir * ".gif", fps=20) # hide
|
||||||
|
|
||||||
|
# ## MNIST
|
||||||
|
|
||||||
|
dataset = MNIST(:test)[16].features # hide
|
||||||
|
dataset = rotl90(dataset) # hide
|
||||||
|
dataset = normalize_neg_one_to_one(dataset) # hide
|
||||||
|
|
||||||
|
noise = randn(Float32, size(dataset)) # hide
|
||||||
|
anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20), dims=1) # hide
|
||||||
|
plots = [] # hide
|
||||||
|
for (i, (scheduler, beta_schedule)) in enumerate(zip(schedulers, beta_schedules)) # hide
|
||||||
|
if t == 0 # hide
|
||||||
|
plot = heatmap( # hide
|
||||||
|
dataset, # hide
|
||||||
|
c=:grayC, # hide
|
||||||
|
legend=:none, # hide
|
||||||
|
aspect_ratio=:equal, # hide
|
||||||
|
grid=false, # hide
|
||||||
|
axis=false # hide
|
||||||
|
) # hide
|
||||||
|
title!(string(beta_schedule)) # hide
|
||||||
|
else # hide
|
||||||
|
noisy_data = forward(scheduler, dataset, noise, [t]) # hide
|
||||||
|
plot = heatmap( # hide
|
||||||
|
noisy_data, # hide
|
||||||
|
c=:grayC, # hide
|
||||||
|
legend=:none, # hide
|
||||||
|
aspect_ratio=:equal, # hide
|
||||||
|
grid=false, # hide
|
||||||
|
axis=false # hide
|
||||||
|
) # hide
|
||||||
|
title!(string(beta_schedule)) # hide
|
||||||
|
end # hide
|
||||||
|
|
||||||
|
push!(plots, plot) # hide
|
||||||
|
end # hide
|
||||||
|
plot(plots...; size=(1200, 800)) # hide
|
||||||
|
end # hide
|
||||||
|
|
||||||
|
gif(anim, anim.dir * ".gif", fps=20) # hide
|
|
@ -1,6 +1,10 @@
|
||||||
"""
|
"""
|
||||||
Cosine beta schedule.
|
Cosine beta schedule.
|
||||||
|
|
||||||
|
```math
|
||||||
|
\\overline{\\alpha}_t = \\cos \\left( \\frac{t / T + \\epsilon}{1 + \\epsilon} \\frac{\\pi}{2} \\right)
|
||||||
|
```
|
||||||
|
|
||||||
## Input
|
## Input
|
||||||
* `T::Int`: number of timesteps
|
* `T::Int`: number of timesteps
|
||||||
* `βₘₐₓ::Real=0.999f0`: maximum value of β
|
* `βₘₐₓ::Real=0.999f0`: maximum value of β
|
||||||
|
@ -10,8 +14,8 @@ Cosine beta schedule.
|
||||||
* `β::Vector{Real}`: βₜ values at each timestep t
|
* `β::Vector{Real}`: βₜ values at each timestep t
|
||||||
|
|
||||||
## References
|
## References
|
||||||
* [[2102.09672] Improved Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2102.09672)
|
* [nichol2021improved; Improved Denoising Diffusion Probabilistic Models](@cite)
|
||||||
* [github:openai/improved-diffusion](https://github.com/openai/improved-diffusion/blob/783b6740edb79fdb7d063250db2c51cc9545dcd1/improved_diffusion/gaussian_diffusion.py#L36)
|
* [github:openai/improved-diffusion/improved_diffusion/gaussian_diffusion.py](https://github.com/openai/improved-diffusion/blob/783b6740edb79fdb7d063250db2c51cc9545dcd1/improved_diffusion/gaussian_diffusion.py#L36)
|
||||||
"""
|
"""
|
||||||
function cosine_beta_schedule(T::Integer, βₘₐₓ::Real=0.999f0, ϵ::Real=1.0f-3)
|
function cosine_beta_schedule(T::Integer, βₘₐₓ::Real=0.999f0, ϵ::Real=1.0f-3)
|
||||||
α̅(t) = cos((t / T + ϵ) / (1 + ϵ) * π / 2)^2
|
α̅(t) = cos((t / T + ϵ) / (1 + ϵ) * π / 2)^2
|
||||||
|
|
|
@ -1,14 +1,16 @@
|
||||||
"""
|
"""
|
||||||
Exponential beta schedule.
|
Exponential beta schedule.
|
||||||
|
|
||||||
|
```math
|
||||||
|
\\overline{\\alpha}_t = \\exp \\left( \\frac{-12 t}{T} \\right)
|
||||||
|
```
|
||||||
|
|
||||||
## Input
|
## Input
|
||||||
* `T::Int`: number of timesteps
|
* `T::Int`: number of timesteps
|
||||||
* `βₘₐₓ::Real=0.999f0`: maximum value of β
|
* `βₘₐₓ::Real=0.999f0`: maximum value of β
|
||||||
|
|
||||||
## Output
|
## Output
|
||||||
* `β::Vector{Real}`: βₜ values at each timestep t
|
* `β::Vector{Real}`: βₜ values at each timestep t
|
||||||
|
|
||||||
## References
|
|
||||||
"""
|
"""
|
||||||
function exponential_beta_schedule(T::Integer, βₘₐₓ::Real=0.999f0)
|
function exponential_beta_schedule(T::Integer, βₘₐₓ::Real=0.999f0)
|
||||||
α̅(t) = exp(-12 * t / T)
|
α̅(t) = exp(-12 * t / T)
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
"""
|
"""
|
||||||
Linear beta schedule.
|
Linear beta schedule.
|
||||||
|
|
||||||
|
```math
|
||||||
|
\\beta_t = \\beta_1 + \\frac{t - 1}{T - 1} (\\beta_{-1} - \\beta_1)
|
||||||
|
```
|
||||||
|
|
||||||
## Input
|
## Input
|
||||||
* `T::Integer`: number of timesteps
|
* `T::Integer`: number of timesteps
|
||||||
* `β₁::Real=1.0f-4`: initial (t=1) value of β
|
* `β₁::Real=1.0f-4`: initial (t=1) value of β
|
||||||
|
@ -10,7 +14,7 @@ Linear beta schedule.
|
||||||
* `β::Vector{Real}`: βₜ values at each timestep t
|
* `β::Vector{Real}`: βₜ values at each timestep t
|
||||||
|
|
||||||
## References
|
## References
|
||||||
* [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
|
* [ho2020denoising; Denoising Diffusion Probabilistic Models](@cite)
|
||||||
"""
|
"""
|
||||||
function linear_beta_schedule(T::Integer, β₁::Real=1.0f-4, β₋₁::Real=2.0f-2)
|
function linear_beta_schedule(T::Integer, β₁::Real=1.0f-4, β₋₁::Real=2.0f-2)
|
||||||
return range(start=β₁, stop=β₋₁, length=T)
|
return range(start=β₁, stop=β₋₁, length=T)
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
"""
|
"""
|
||||||
Scaled linear beta schedule.
|
Scaled linear beta schedule.
|
||||||
|
|
||||||
|
```math
|
||||||
|
\\beta_t = \\left( \\sqrt{\\beta_1} + \\frac{t - 1}{T - 1} \\left( \\sqrt{\\beta_{-1}} - \\sqrt{\\beta_1} \\right) \\right)^2
|
||||||
|
```
|
||||||
|
|
||||||
## Input
|
## Input
|
||||||
* `T::Int`: number of timesteps
|
* `T::Int`: number of timesteps
|
||||||
* `β₁::Real=1.0f-4`: initial value of β
|
* `β₁::Real=1.0f-4`: initial value of β
|
||||||
|
@ -10,7 +14,7 @@ Scaled linear beta schedule.
|
||||||
* `β::Vector{Real}`: βₜ values at each timestep t
|
* `β::Vector{Real}`: βₜ values at each timestep t
|
||||||
|
|
||||||
## References
|
## References
|
||||||
* [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
|
* [ho2020denoising; Denoising Diffusion Probabilistic Models](@cite)
|
||||||
"""
|
"""
|
||||||
function scaled_linear_beta_schedule(T::Integer, β₁::Real=1.0f-4, β₋₁::Real=2.0f-2)
|
function scaled_linear_beta_schedule(T::Integer, β₁::Real=1.0f-4, β₋₁::Real=2.0f-2)
|
||||||
return range(start=√β₁, stop=√β₋₁, length=T) .^ 2
|
return range(start=√β₁, stop=√β₋₁, length=T) .^ 2
|
||||||
|
|
|
@ -3,6 +3,10 @@ import NNlib: sigmoid
|
||||||
"""
|
"""
|
||||||
Sigmoid beta schedule.
|
Sigmoid beta schedule.
|
||||||
|
|
||||||
|
```math
|
||||||
|
\\beta_t = \\sigma \\left( 12 \\frac{t - 1}{T - 1} - 6 \\right) ( \\beta_{-1} - \\beta_1 ) + \\beta_1
|
||||||
|
```
|
||||||
|
|
||||||
## Input
|
## Input
|
||||||
* `T::Int`: number of timesteps
|
* `T::Int`: number of timesteps
|
||||||
* `β₁::Real=1.0f-4`: initial value of β
|
* `β₁::Real=1.0f-4`: initial value of β
|
||||||
|
@ -12,8 +16,8 @@ Sigmoid beta schedule.
|
||||||
* `β::Vector{Real}`: βₜ values at each timestep t
|
* `β::Vector{Real}`: βₜ values at each timestep t
|
||||||
|
|
||||||
## References
|
## References
|
||||||
* [[2203.02923] GeoDiff: a Geometric Diffusion Model for Molecular Conformation Generation](https://arxiv.org/abs/2203.02923)
|
* [xu2022geodiff; GeoDiff: a Geometric Diffusion Model for Molecular Conformation Generation](@cite)
|
||||||
* [github.com:MinkaiXu/GeoDiff](https://github.com/MinkaiXu/GeoDiff/blob/ea0ca48045a2f7abfccd7f0df449e45eb6eae638/models/epsnet/diffusion.py#L57)
|
* [github.com:MinkaiXu/GeoDiff/models/epsnet/diffusion.py](https://github.com/MinkaiXu/GeoDiff/blob/ea0ca48045a2f7abfccd7f0df449e45eb6eae638/models/epsnet/diffusion.py#L57)
|
||||||
"""
|
"""
|
||||||
function sigmoid_beta_schedule(T::Integer, β₁::Real=1.0f-4, β₋₁::Real=2.0f-2)
|
function sigmoid_beta_schedule(T::Integer, β₁::Real=1.0f-4, β₋₁::Real=2.0f-2)
|
||||||
x = range(start=-6, stop=6, length=T)
|
x = range(start=-6, stop=6, length=T)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""
|
"""
|
||||||
Rescale betas to have zero terminal Signal to Noise Ratio (SNR).
|
Rescale β to have zero terminal Signal to Noise Ratio (SNR).
|
||||||
|
|
||||||
## Input
|
## Input
|
||||||
* `β::AbstractArray`: βₜ values at each timestep t
|
* `β::AbstractArray`: βₜ values at each timestep t
|
||||||
|
@ -8,7 +8,7 @@ Rescale betas to have zero terminal Signal to Noise Ratio (SNR).
|
||||||
* `β::Vector{Real}`: rescaled βₜ values at each timestep t
|
* `β::Vector{Real}`: rescaled βₜ values at each timestep t
|
||||||
|
|
||||||
## References
|
## References
|
||||||
* [[2305.08891] Rescaling Diffusion Models](https://arxiv.org/abs/2305.08891) (Alg. 1)
|
* [lin2023common; Rescaling Diffusion Models (Alg. 1)](@cite)
|
||||||
"""
|
"""
|
||||||
function rescale_zero_terminal_snr(β::AbstractArray)
|
function rescale_zero_terminal_snr(β::AbstractArray)
|
||||||
# convert β to ⎷α̅
|
# convert β to ⎷α̅
|
||||||
|
|
|
@ -27,7 +27,7 @@ export
|
||||||
include("Schedulers/Schedulers.jl")
|
include("Schedulers/Schedulers.jl")
|
||||||
|
|
||||||
import .Schedulers:
|
import .Schedulers:
|
||||||
# Scheduler
|
# Schedulers
|
||||||
DDPM,
|
DDPM,
|
||||||
DDIM,
|
DDIM,
|
||||||
|
|
||||||
|
@ -51,7 +51,7 @@ import .Schedulers:
|
||||||
VELOCITY
|
VELOCITY
|
||||||
|
|
||||||
export
|
export
|
||||||
# Scheduler
|
# Schedulers
|
||||||
DDPM,
|
DDPM,
|
||||||
DDIM,
|
DDIM,
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
|
@enum VarianceType FIXED_SMALL FIXED_SMALL_LOG FIXED_LARGE FIXED_LARGE_LOG LEARNED
|
||||||
|
@enum PredictionType EPSILON SAMPLE VELOCITY
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Abstract type for schedulers.
|
Abstract type for schedulers.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
abstract type Scheduler end
|
abstract type Scheduler end
|
||||||
|
|
||||||
@enum VarianceType FIXED_SMALL FIXED_SMALL_LOG FIXED_LARGE FIXED_LARGE_LOG LEARNED
|
|
||||||
@enum PredictionType EPSILON SAMPLE VELOCITY
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Add noise to clean data using the forward diffusion process.
|
Add noise to clean data using the forward diffusion process.
|
||||||
|
@ -59,7 +59,7 @@ Compute the velocity of the diffusion process.
|
||||||
* `vₜ::AbstractArray`: velocity at the given timesteps
|
* `vₜ::AbstractArray`: velocity at the given timesteps
|
||||||
|
|
||||||
## References
|
## References
|
||||||
* [[2202.00512] Progressive Distillation for Fast Sampling of Diffusion Models](https://arxiv.org/abs/2202.00512) (Ann. D)
|
[salimans2022progressive; Ann. D](@cite)
|
||||||
"""
|
"""
|
||||||
function get_velocity(
|
function get_velocity(
|
||||||
scheduler::Scheduler,
|
scheduler::Scheduler,
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
include("Abstract.jl")
|
|
||||||
|
|
||||||
using ShiftedArrays
|
using ShiftedArrays
|
||||||
|
|
||||||
function _extract(
|
function _extract(
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
include("Abstract.jl")
|
|
||||||
|
|
||||||
using ShiftedArrays
|
using ShiftedArrays
|
||||||
|
|
||||||
function _extract(
|
function _extract(
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
module Schedulers
|
module Schedulers
|
||||||
|
|
||||||
|
include("Abstract.jl")
|
||||||
include("DDPM.jl")
|
include("DDPM.jl")
|
||||||
include("DDIM.jl")
|
include("DDIM.jl")
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue