mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-10-23 01:56:20 +00:00
Compare commits
No commits in common. "1a317e470372442791625722f64b30b5e3d9e28f" and "0780fef33586f63c85405e398f5544c1b29ccff0" have entirely different histories.
1a317e4703
...
0780fef335
2
.github/workflows/documentation.yml
vendored
2
.github/workflows/documentation.yml
vendored
|
@ -19,7 +19,7 @@ jobs:
|
||||||
with:
|
with:
|
||||||
version: '1.9'
|
version: '1.9'
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: julia --project=docs/ -e 'using Pkg; Pkg.instantiate()'
|
run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
|
||||||
- name: Build and deploy
|
- name: Build and deploy
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1,7 +1,6 @@
|
||||||
.direnv
|
.direnv
|
||||||
*.gif
|
*.gif
|
||||||
*.jld2
|
*.jld2
|
||||||
generated/
|
|
||||||
|
|
||||||
# https://github.com/github/gitignore/blob/main/Julia.gitignore
|
# https://github.com/github/gitignore/blob/main/Julia.gitignore
|
||||||
# Files generated by invoking Julia with --code-coverage
|
# Files generated by invoking Julia with --code-coverage
|
||||||
|
|
1
.vscode/settings.json
vendored
1
.vscode/settings.json
vendored
|
@ -1,4 +1,5 @@
|
||||||
{
|
{
|
||||||
"julia.useRevise": true,
|
"julia.useRevise": true,
|
||||||
|
"julia.persistentSession.enabled": true,
|
||||||
"editor.unicodeHighlight.ambiguousCharacters": false,
|
"editor.unicodeHighlight.ambiguousCharacters": false,
|
||||||
}
|
}
|
2822
docs/Manifest.toml
2822
docs/Manifest.toml
File diff suppressed because it is too large
Load diff
|
@ -1,12 +1,6 @@
|
||||||
[deps]
|
[deps]
|
||||||
DenoisingDiffusion = "32e9e46b-ad0f-4c80-b32d-4f6f824844ef"
|
|
||||||
Diffusers = "90edb7a8-79d7-49b2-b6b1-9322c3fdead8"
|
Diffusers = "90edb7a8-79d7-49b2-b6b1-9322c3fdead8"
|
||||||
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
|
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
|
||||||
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
|
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
|
||||||
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
|
|
||||||
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
|
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
|
||||||
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
|
|
||||||
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
|
|
||||||
PlotlyJS = "f0f68f2c-4968-5e81-91da-67840de0976a"
|
PlotlyJS = "f0f68f2c-4968-5e81-91da-67840de0976a"
|
||||||
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
|
|
||||||
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
|
|
||||||
|
|
12
docs/make.jl
12
docs/make.jl
|
@ -1,9 +1,6 @@
|
||||||
using Diffusers
|
using Diffusers
|
||||||
using Documenter
|
using Documenter
|
||||||
using DocumenterCitations
|
using DocumenterCitations
|
||||||
using Literate
|
|
||||||
|
|
||||||
Literate.markdown(joinpath(@__DIR__, "..", "examples", "beta_schedulers_comparison.jl"), joinpath(@__DIR__, "src", "generated"))
|
|
||||||
|
|
||||||
DocMeta.setdocmeta!(Diffusers, :DocTestSetup, :(using Diffusers); recursive=true)
|
DocMeta.setdocmeta!(Diffusers, :DocTestSetup, :(using Diffusers); recursive=true)
|
||||||
bib = CitationBibliography(joinpath(@__DIR__, "src", "refs.bib"))
|
bib = CitationBibliography(joinpath(@__DIR__, "src", "refs.bib"))
|
||||||
|
@ -21,13 +18,8 @@ makedocs(bib;
|
||||||
linkcheck=true,
|
linkcheck=true,
|
||||||
pages=[
|
pages=[
|
||||||
"Home" => "index.md",
|
"Home" => "index.md",
|
||||||
"API" => [
|
"Schedulers" => "schedulers.md",
|
||||||
"Schedulers" => "schedulers.md",
|
"Beta Schedules" => "beta_schedules.md",
|
||||||
"Beta Schedules" => "beta_schedules.md",
|
|
||||||
],
|
|
||||||
"Examples" => [
|
|
||||||
"Beta Schedules Comparison" => "generated/beta_schedulers_comparison.md",
|
|
||||||
],
|
|
||||||
"References" => "references.md",
|
"References" => "references.md",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -33,30 +33,3 @@
|
||||||
archiveprefix = {arXiv},
|
archiveprefix = {arXiv},
|
||||||
primaryclass = {cs.LG}
|
primaryclass = {cs.LG}
|
||||||
}
|
}
|
||||||
|
|
||||||
@misc{lin2023common,
|
|
||||||
title = {Common Diffusion Noise Schedules and Sample Steps are Flawed},
|
|
||||||
author = {Shanchuan Lin and Bingchen Liu and Jiashi Li and Xiao Yang},
|
|
||||||
year = {2023},
|
|
||||||
eprint = {2305.08891},
|
|
||||||
archiveprefix = {arXiv},
|
|
||||||
primaryclass = {cs.CV}
|
|
||||||
}
|
|
||||||
|
|
||||||
@misc{xu2022geodiff,
|
|
||||||
title = {GeoDiff: a Geometric Diffusion Model for Molecular Conformation Generation},
|
|
||||||
author = {Minkai Xu and Lantao Yu and Yang Song and Chence Shi and Stefano Ermon and Jian Tang},
|
|
||||||
year = {2022},
|
|
||||||
eprint = {2203.02923},
|
|
||||||
archiveprefix = {arXiv},
|
|
||||||
primaryclass = {cs.LG}
|
|
||||||
}
|
|
||||||
|
|
||||||
@misc{salimans2022progressive,
|
|
||||||
title = {Progressive Distillation for Fast Sampling of Diffusion Models},
|
|
||||||
author = {Tim Salimans and Jonathan Ho},
|
|
||||||
year = {2022},
|
|
||||||
eprint = {2202.00512},
|
|
||||||
archiveprefix = {arXiv},
|
|
||||||
primaryclass = {cs.LG}
|
|
||||||
}
|
|
|
@ -1,180 +0,0 @@
|
||||||
# This example compares the different beta schedules available in Diffusers.jl.
|
|
||||||
# Code related to the generation of the datasets and the plots is hidden.
|
|
||||||
|
|
||||||
using Diffusers.Schedulers: DDPM, forward # hide
|
|
||||||
using Diffusers.BetaSchedules # hide
|
|
||||||
using ProgressMeter # hide
|
|
||||||
using LaTeXStrings # hide
|
|
||||||
using Random # hide
|
|
||||||
using Plots # hide
|
|
||||||
using Flux # hide
|
|
||||||
using MLDatasets # hide
|
|
||||||
|
|
||||||
function normalize_zero_to_one(x) # hide
|
|
||||||
x_min, x_max = extrema(x) # hide
|
|
||||||
x_norm = (x .- x_min) ./ (x_max - x_min) # hide
|
|
||||||
x_norm # hide
|
|
||||||
end # hide
|
|
||||||
|
|
||||||
function normalize_neg_one_to_one(x) # hide
|
|
||||||
2 * normalize_zero_to_one(x) .- 1 # hide
|
|
||||||
end # hide
|
|
||||||
|
|
||||||
num_timesteps = 100 # hide
|
|
||||||
beta_schedules = [ # hide
|
|
||||||
linear_beta_schedule, # hide
|
|
||||||
scaled_linear_beta_schedule, # hide
|
|
||||||
cosine_beta_schedule, # hide
|
|
||||||
sigmoid_beta_schedule, # hide
|
|
||||||
exponential_beta_schedule, # hide
|
|
||||||
] # hide
|
|
||||||
schedulers = [ # hide
|
|
||||||
DDPM(collect(schedule(num_timesteps))) for schedule in beta_schedules # hide
|
|
||||||
]; # hide
|
|
||||||
|
|
||||||
# ## Swiss Roll
|
|
||||||
|
|
||||||
function make_spiral(n_samples::Integer=1000, t_min::Real=1.5π, t_max::Real=4.5π) # hide
|
|
||||||
t = rand(typeof(t_min), n_samples) * (t_max - t_min) .+ t_min # hide
|
|
||||||
|
|
||||||
x = t .* cos.(t) # hide
|
|
||||||
y = t .* sin.(t) # hide
|
|
||||||
|
|
||||||
permutedims([x y], (2, 1)) # hide
|
|
||||||
end # hide
|
|
||||||
|
|
||||||
n_points = 1000; # hide
|
|
||||||
dataset = make_spiral(n_points, 1.5f0 * π, 4.5f0 * π); # hide
|
|
||||||
dataset = normalize_neg_one_to_one(dataset); # hide
|
|
||||||
|
|
||||||
noise = randn(Float32, size(dataset)) # hide
|
|
||||||
anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20), dims=1) # hide
|
|
||||||
plots = [] # hide
|
|
||||||
for (i, (scheduler, beta_schedule)) in enumerate(zip(schedulers, beta_schedules)) # hide
|
|
||||||
if t == 0 # hide
|
|
||||||
scatter(dataset[1, :], dataset[2, :], # hide
|
|
||||||
alpha=0.5, # hide
|
|
||||||
aspectratio=:equal, # hide
|
|
||||||
legend=false, # hide
|
|
||||||
) # hide
|
|
||||||
plot = scatter!(dataset[1, :], dataset[2, :], # hide
|
|
||||||
alpha=0.5, # hide
|
|
||||||
aspectratio=:equal, # hide
|
|
||||||
) # hide
|
|
||||||
title!(string(beta_schedule)) # hide
|
|
||||||
xlims!(-3, 3) # hide
|
|
||||||
ylims!(-3, 3) # hide
|
|
||||||
else # hide
|
|
||||||
scatter(dataset[1, :], dataset[2, :], # hide
|
|
||||||
alpha=0.5, # hide
|
|
||||||
aspectratio=:equal, # hide
|
|
||||||
legend=false, # hide
|
|
||||||
) # hide
|
|
||||||
noisy_data = forward(scheduler, dataset, noise, [t]) # hide
|
|
||||||
plot = scatter!(noisy_data[1, :], noisy_data[2, :], # hide
|
|
||||||
alpha=0.5, # hide
|
|
||||||
aspectratio=:equal, # hide
|
|
||||||
) # hide
|
|
||||||
title!(string(beta_schedule)) # hide
|
|
||||||
xlims!(-3, 3) # hide
|
|
||||||
ylims!(-3, 3) # hide
|
|
||||||
end # hide
|
|
||||||
push!(plots, plot) # hide
|
|
||||||
end # hide
|
|
||||||
plot(plots...; size=(1200, 800)) # hide
|
|
||||||
end # hide
|
|
||||||
|
|
||||||
gif(anim, anim.dir * ".gif", fps=20) # hide
|
|
||||||
|
|
||||||
# ## Double Square
|
|
||||||
|
|
||||||
function make_square(n_samples::Integer=1000) # hide
|
|
||||||
x = rand(n_samples) .* 2 .- 1 # hide
|
|
||||||
y = rand(n_samples) .* 2 .- 1 # hide
|
|
||||||
p = permutedims([x y], (2, 1)) # hide
|
|
||||||
p ./ maximum(abs.(p), dims=1) # hide
|
|
||||||
end # hide
|
|
||||||
|
|
||||||
dataset = hcat( # hide
|
|
||||||
make_square(Int(n_points / 2)) ./ 2 .- 1.5, # hide
|
|
||||||
make_square(Int(n_points / 2)) ./ 2 .+ 1.5 # hide
|
|
||||||
) # hide
|
|
||||||
|
|
||||||
noise = randn(Float32, size(dataset)) # hide
|
|
||||||
anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20), dims=1) # hide
|
|
||||||
plots = [] # hide
|
|
||||||
for (i, (scheduler, beta_schedule)) in enumerate(zip(schedulers, beta_schedules)) # hide
|
|
||||||
if t == 0 # hide
|
|
||||||
scatter(dataset[1, :], dataset[2, :], # hide
|
|
||||||
alpha=0.5, # hide
|
|
||||||
aspectratio=:equal, # hide
|
|
||||||
legend=false, # hide
|
|
||||||
) # hide
|
|
||||||
plot = scatter!(dataset[1, :], dataset[2, :], # hide
|
|
||||||
alpha=0.5, # hide
|
|
||||||
aspectratio=:equal, # hide
|
|
||||||
) # hide
|
|
||||||
title!(string(beta_schedule)) # hide
|
|
||||||
xlims!(-3, 3) # hide
|
|
||||||
ylims!(-3, 3) # hide
|
|
||||||
else # hide
|
|
||||||
scatter(dataset[1, :], dataset[2, :], # hide
|
|
||||||
alpha=0.5, # hide
|
|
||||||
aspectratio=:equal, # hide
|
|
||||||
legend=false, # hide
|
|
||||||
) # hide
|
|
||||||
noisy_data = forward(scheduler, dataset, noise, [t]) # hide
|
|
||||||
plot = scatter!(noisy_data[1, :], noisy_data[2, :], # hide
|
|
||||||
alpha=0.5, # hide
|
|
||||||
aspectratio=:equal, # hide
|
|
||||||
) # hide
|
|
||||||
title!(string(beta_schedule)) # hide
|
|
||||||
xlims!(-3, 3) # hide
|
|
||||||
ylims!(-3, 3) # hide
|
|
||||||
end # hide
|
|
||||||
push!(plots, plot) # hide
|
|
||||||
end # hide
|
|
||||||
plot(plots...; size=(1200, 800)) # hide
|
|
||||||
end # hide
|
|
||||||
|
|
||||||
gif(anim, anim.dir * ".gif", fps=20) # hide
|
|
||||||
|
|
||||||
# ## MNIST
|
|
||||||
|
|
||||||
dataset = MNIST(:test)[16].features # hide
|
|
||||||
dataset = rotl90(dataset) # hide
|
|
||||||
dataset = normalize_neg_one_to_one(dataset) # hide
|
|
||||||
|
|
||||||
noise = randn(Float32, size(dataset)) # hide
|
|
||||||
anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20), dims=1) # hide
|
|
||||||
plots = [] # hide
|
|
||||||
for (i, (scheduler, beta_schedule)) in enumerate(zip(schedulers, beta_schedules)) # hide
|
|
||||||
if t == 0 # hide
|
|
||||||
plot = heatmap( # hide
|
|
||||||
dataset, # hide
|
|
||||||
c=:grayC, # hide
|
|
||||||
legend=:none, # hide
|
|
||||||
aspect_ratio=:equal, # hide
|
|
||||||
grid=false, # hide
|
|
||||||
axis=false # hide
|
|
||||||
) # hide
|
|
||||||
title!(string(beta_schedule)) # hide
|
|
||||||
else # hide
|
|
||||||
noisy_data = forward(scheduler, dataset, noise, [t]) # hide
|
|
||||||
plot = heatmap( # hide
|
|
||||||
noisy_data, # hide
|
|
||||||
c=:grayC, # hide
|
|
||||||
legend=:none, # hide
|
|
||||||
aspect_ratio=:equal, # hide
|
|
||||||
grid=false, # hide
|
|
||||||
axis=false # hide
|
|
||||||
) # hide
|
|
||||||
title!(string(beta_schedule)) # hide
|
|
||||||
end # hide
|
|
||||||
|
|
||||||
push!(plots, plot) # hide
|
|
||||||
end # hide
|
|
||||||
plot(plots...; size=(1200, 800)) # hide
|
|
||||||
end # hide
|
|
||||||
|
|
||||||
gif(anim, anim.dir * ".gif", fps=20) # hide
|
|
|
@ -1,10 +1,6 @@
|
||||||
"""
|
"""
|
||||||
Cosine beta schedule.
|
Cosine beta schedule.
|
||||||
|
|
||||||
```math
|
|
||||||
\\overline{\\alpha}_t = \\cos \\left( \\frac{t / T + \\epsilon}{1 + \\epsilon} \\frac{\\pi}{2} \\right)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Input
|
## Input
|
||||||
* `T::Int`: number of timesteps
|
* `T::Int`: number of timesteps
|
||||||
* `βₘₐₓ::Real=0.999f0`: maximum value of β
|
* `βₘₐₓ::Real=0.999f0`: maximum value of β
|
||||||
|
@ -14,8 +10,8 @@ Cosine beta schedule.
|
||||||
* `β::Vector{Real}`: βₜ values at each timestep t
|
* `β::Vector{Real}`: βₜ values at each timestep t
|
||||||
|
|
||||||
## References
|
## References
|
||||||
* [nichol2021improved; Improved Denoising Diffusion Probabilistic Models](@cite)
|
* [[2102.09672] Improved Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2102.09672)
|
||||||
* [github:openai/improved-diffusion/improved_diffusion/gaussian_diffusion.py](https://github.com/openai/improved-diffusion/blob/783b6740edb79fdb7d063250db2c51cc9545dcd1/improved_diffusion/gaussian_diffusion.py#L36)
|
* [github:openai/improved-diffusion](https://github.com/openai/improved-diffusion/blob/783b6740edb79fdb7d063250db2c51cc9545dcd1/improved_diffusion/gaussian_diffusion.py#L36)
|
||||||
"""
|
"""
|
||||||
function cosine_beta_schedule(T::Integer, βₘₐₓ::Real=0.999f0, ϵ::Real=1.0f-3)
|
function cosine_beta_schedule(T::Integer, βₘₐₓ::Real=0.999f0, ϵ::Real=1.0f-3)
|
||||||
α̅(t) = cos((t / T + ϵ) / (1 + ϵ) * π / 2)^2
|
α̅(t) = cos((t / T + ϵ) / (1 + ϵ) * π / 2)^2
|
||||||
|
|
|
@ -1,16 +1,14 @@
|
||||||
"""
|
"""
|
||||||
Exponential beta schedule.
|
Exponential beta schedule.
|
||||||
|
|
||||||
```math
|
|
||||||
\\overline{\\alpha}_t = \\exp \\left( \\frac{-12 t}{T} \\right)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Input
|
## Input
|
||||||
* `T::Int`: number of timesteps
|
* `T::Int`: number of timesteps
|
||||||
* `βₘₐₓ::Real=0.999f0`: maximum value of β
|
* `βₘₐₓ::Real=0.999f0`: maximum value of β
|
||||||
|
|
||||||
## Output
|
## Output
|
||||||
* `β::Vector{Real}`: βₜ values at each timestep t
|
* `β::Vector{Real}`: βₜ values at each timestep t
|
||||||
|
|
||||||
|
## References
|
||||||
"""
|
"""
|
||||||
function exponential_beta_schedule(T::Integer, βₘₐₓ::Real=0.999f0)
|
function exponential_beta_schedule(T::Integer, βₘₐₓ::Real=0.999f0)
|
||||||
α̅(t) = exp(-12 * t / T)
|
α̅(t) = exp(-12 * t / T)
|
||||||
|
|
|
@ -1,10 +1,6 @@
|
||||||
"""
|
"""
|
||||||
Linear beta schedule.
|
Linear beta schedule.
|
||||||
|
|
||||||
```math
|
|
||||||
\\beta_t = \\beta_1 + \\frac{t - 1}{T - 1} (\\beta_{-1} - \\beta_1)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Input
|
## Input
|
||||||
* `T::Integer`: number of timesteps
|
* `T::Integer`: number of timesteps
|
||||||
* `β₁::Real=1.0f-4`: initial (t=1) value of β
|
* `β₁::Real=1.0f-4`: initial (t=1) value of β
|
||||||
|
@ -14,7 +10,7 @@ Linear beta schedule.
|
||||||
* `β::Vector{Real}`: βₜ values at each timestep t
|
* `β::Vector{Real}`: βₜ values at each timestep t
|
||||||
|
|
||||||
## References
|
## References
|
||||||
* [ho2020denoising; Denoising Diffusion Probabilistic Models](@cite)
|
* [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
|
||||||
"""
|
"""
|
||||||
function linear_beta_schedule(T::Integer, β₁::Real=1.0f-4, β₋₁::Real=2.0f-2)
|
function linear_beta_schedule(T::Integer, β₁::Real=1.0f-4, β₋₁::Real=2.0f-2)
|
||||||
return range(start=β₁, stop=β₋₁, length=T)
|
return range(start=β₁, stop=β₋₁, length=T)
|
||||||
|
|
|
@ -1,10 +1,6 @@
|
||||||
"""
|
"""
|
||||||
Scaled linear beta schedule.
|
Scaled linear beta schedule.
|
||||||
|
|
||||||
```math
|
|
||||||
\\beta_t = \\left( \\sqrt{\\beta_1} + \\frac{t - 1}{T - 1} \\left( \\sqrt{\\beta_{-1}} - \\sqrt{\\beta_1} \\right) \\right)^2
|
|
||||||
```
|
|
||||||
|
|
||||||
## Input
|
## Input
|
||||||
* `T::Int`: number of timesteps
|
* `T::Int`: number of timesteps
|
||||||
* `β₁::Real=1.0f-4`: initial value of β
|
* `β₁::Real=1.0f-4`: initial value of β
|
||||||
|
@ -14,7 +10,7 @@ Scaled linear beta schedule.
|
||||||
* `β::Vector{Real}`: βₜ values at each timestep t
|
* `β::Vector{Real}`: βₜ values at each timestep t
|
||||||
|
|
||||||
## References
|
## References
|
||||||
* [ho2020denoising; Denoising Diffusion Probabilistic Models](@cite)
|
* [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
|
||||||
"""
|
"""
|
||||||
function scaled_linear_beta_schedule(T::Integer, β₁::Real=1.0f-4, β₋₁::Real=2.0f-2)
|
function scaled_linear_beta_schedule(T::Integer, β₁::Real=1.0f-4, β₋₁::Real=2.0f-2)
|
||||||
return range(start=√β₁, stop=√β₋₁, length=T) .^ 2
|
return range(start=√β₁, stop=√β₋₁, length=T) .^ 2
|
||||||
|
|
|
@ -3,10 +3,6 @@ import NNlib: sigmoid
|
||||||
"""
|
"""
|
||||||
Sigmoid beta schedule.
|
Sigmoid beta schedule.
|
||||||
|
|
||||||
```math
|
|
||||||
\\beta_t = \\sigma \\left( 12 \\frac{t - 1}{T - 1} - 6 \\right) ( \\beta_{-1} - \\beta_1 ) + \\beta_1
|
|
||||||
```
|
|
||||||
|
|
||||||
## Input
|
## Input
|
||||||
* `T::Int`: number of timesteps
|
* `T::Int`: number of timesteps
|
||||||
* `β₁::Real=1.0f-4`: initial value of β
|
* `β₁::Real=1.0f-4`: initial value of β
|
||||||
|
@ -16,8 +12,8 @@ Sigmoid beta schedule.
|
||||||
* `β::Vector{Real}`: βₜ values at each timestep t
|
* `β::Vector{Real}`: βₜ values at each timestep t
|
||||||
|
|
||||||
## References
|
## References
|
||||||
* [xu2022geodiff; GeoDiff: a Geometric Diffusion Model for Molecular Conformation Generation](@cite)
|
* [[2203.02923] GeoDiff: a Geometric Diffusion Model for Molecular Conformation Generation](https://arxiv.org/abs/2203.02923)
|
||||||
* [github.com:MinkaiXu/GeoDiff/models/epsnet/diffusion.py](https://github.com/MinkaiXu/GeoDiff/blob/ea0ca48045a2f7abfccd7f0df449e45eb6eae638/models/epsnet/diffusion.py#L57)
|
* [github.com:MinkaiXu/GeoDiff](https://github.com/MinkaiXu/GeoDiff/blob/ea0ca48045a2f7abfccd7f0df449e45eb6eae638/models/epsnet/diffusion.py#L57)
|
||||||
"""
|
"""
|
||||||
function sigmoid_beta_schedule(T::Integer, β₁::Real=1.0f-4, β₋₁::Real=2.0f-2)
|
function sigmoid_beta_schedule(T::Integer, β₁::Real=1.0f-4, β₋₁::Real=2.0f-2)
|
||||||
x = range(start=-6, stop=6, length=T)
|
x = range(start=-6, stop=6, length=T)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""
|
"""
|
||||||
Rescale β to have zero terminal Signal to Noise Ratio (SNR).
|
Rescale betas to have zero terminal Signal to Noise Ratio (SNR).
|
||||||
|
|
||||||
## Input
|
## Input
|
||||||
* `β::AbstractArray`: βₜ values at each timestep t
|
* `β::AbstractArray`: βₜ values at each timestep t
|
||||||
|
@ -8,7 +8,7 @@ Rescale β to have zero terminal Signal to Noise Ratio (SNR).
|
||||||
* `β::Vector{Real}`: rescaled βₜ values at each timestep t
|
* `β::Vector{Real}`: rescaled βₜ values at each timestep t
|
||||||
|
|
||||||
## References
|
## References
|
||||||
* [lin2023common; Rescaling Diffusion Models (Alg. 1)](@cite)
|
* [[2305.08891] Rescaling Diffusion Models](https://arxiv.org/abs/2305.08891) (Alg. 1)
|
||||||
"""
|
"""
|
||||||
function rescale_zero_terminal_snr(β::AbstractArray)
|
function rescale_zero_terminal_snr(β::AbstractArray)
|
||||||
# convert β to ⎷α̅
|
# convert β to ⎷α̅
|
||||||
|
|
|
@ -27,7 +27,7 @@ export
|
||||||
include("Schedulers/Schedulers.jl")
|
include("Schedulers/Schedulers.jl")
|
||||||
|
|
||||||
import .Schedulers:
|
import .Schedulers:
|
||||||
# Schedulers
|
# Scheduler
|
||||||
DDPM,
|
DDPM,
|
||||||
DDIM,
|
DDIM,
|
||||||
|
|
||||||
|
@ -51,7 +51,7 @@ import .Schedulers:
|
||||||
VELOCITY
|
VELOCITY
|
||||||
|
|
||||||
export
|
export
|
||||||
# Schedulers
|
# Scheduler
|
||||||
DDPM,
|
DDPM,
|
||||||
DDIM,
|
DDIM,
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
@enum VarianceType FIXED_SMALL FIXED_SMALL_LOG FIXED_LARGE FIXED_LARGE_LOG LEARNED
|
|
||||||
@enum PredictionType EPSILON SAMPLE VELOCITY
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Abstract type for schedulers.
|
Abstract type for schedulers.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
abstract type Scheduler end
|
abstract type Scheduler end
|
||||||
|
|
||||||
|
@enum VarianceType FIXED_SMALL FIXED_SMALL_LOG FIXED_LARGE FIXED_LARGE_LOG LEARNED
|
||||||
|
@enum PredictionType EPSILON SAMPLE VELOCITY
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Add noise to clean data using the forward diffusion process.
|
Add noise to clean data using the forward diffusion process.
|
||||||
|
@ -59,7 +59,7 @@ Compute the velocity of the diffusion process.
|
||||||
* `vₜ::AbstractArray`: velocity at the given timesteps
|
* `vₜ::AbstractArray`: velocity at the given timesteps
|
||||||
|
|
||||||
## References
|
## References
|
||||||
[salimans2022progressive; Ann. D](@cite)
|
* [[2202.00512] Progressive Distillation for Fast Sampling of Diffusion Models](https://arxiv.org/abs/2202.00512) (Ann. D)
|
||||||
"""
|
"""
|
||||||
function get_velocity(
|
function get_velocity(
|
||||||
scheduler::Scheduler,
|
scheduler::Scheduler,
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
include("Abstract.jl")
|
||||||
|
|
||||||
using ShiftedArrays
|
using ShiftedArrays
|
||||||
|
|
||||||
function _extract(
|
function _extract(
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
include("Abstract.jl")
|
||||||
|
|
||||||
using ShiftedArrays
|
using ShiftedArrays
|
||||||
|
|
||||||
function _extract(
|
function _extract(
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
module Schedulers
|
module Schedulers
|
||||||
|
|
||||||
include("Abstract.jl")
|
|
||||||
include("DDPM.jl")
|
include("DDPM.jl")
|
||||||
include("DDIM.jl")
|
include("DDIM.jl")
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue