Compare commits

...

5 commits

19 changed files with 3082 additions and 24 deletions

View file

@ -19,7 +19,7 @@ jobs:
with: with:
version: '1.9' version: '1.9'
- name: Install dependencies - name: Install dependencies
run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' run: julia --project=docs/ -e 'using Pkg; Pkg.instantiate()'
- name: Build and deploy - name: Build and deploy
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

1
.gitignore vendored
View file

@ -1,6 +1,7 @@
.direnv .direnv
*.gif *.gif
*.jld2 *.jld2
generated/
# https://github.com/github/gitignore/blob/main/Julia.gitignore # https://github.com/github/gitignore/blob/main/Julia.gitignore
# Files generated by invoking Julia with --code-coverage # Files generated by invoking Julia with --code-coverage

View file

@ -1,5 +1,4 @@
{ {
"julia.useRevise": true, "julia.useRevise": true,
"julia.persistentSession.enabled": true,
"editor.unicodeHighlight.ambiguousCharacters": false, "editor.unicodeHighlight.ambiguousCharacters": false,
} }

2822
docs/Manifest.toml Normal file

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,12 @@
[deps] [deps]
DenoisingDiffusion = "32e9e46b-ad0f-4c80-b32d-4f6f824844ef"
Diffusers = "90edb7a8-79d7-49b2-b6b1-9322c3fdead8" Diffusers = "90edb7a8-79d7-49b2-b6b1-9322c3fdead8"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244" DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
PlotlyJS = "f0f68f2c-4968-5e81-91da-67840de0976a" PlotlyJS = "f0f68f2c-4968-5e81-91da-67840de0976a"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"

View file

@ -1,6 +1,9 @@
using Diffusers using Diffusers
using Documenter using Documenter
using DocumenterCitations using DocumenterCitations
using Literate
# Render the Literate example script to markdown under docs/src/generated, which is
# where the "Examples" page entry (generated/beta_schedulers_comparison.md) points.
Literate.markdown(joinpath(@__DIR__, "..", "examples", "beta_schedulers_comparison.jl"), joinpath(@__DIR__, "src", "generated"))
DocMeta.setdocmeta!(Diffusers, :DocTestSetup, :(using Diffusers); recursive=true) DocMeta.setdocmeta!(Diffusers, :DocTestSetup, :(using Diffusers); recursive=true)
bib = CitationBibliography(joinpath(@__DIR__, "src", "refs.bib")) bib = CitationBibliography(joinpath(@__DIR__, "src", "refs.bib"))
@ -18,8 +21,13 @@ makedocs(bib;
linkcheck=true, linkcheck=true,
pages=[ pages=[
"Home" => "index.md", "Home" => "index.md",
"Schedulers" => "schedulers.md", "API" => [
"Beta Schedules" => "beta_schedules.md", "Schedulers" => "schedulers.md",
"Beta Schedules" => "beta_schedules.md",
],
"Examples" => [
"Beta Schedules Comparison" => "generated/beta_schedulers_comparison.md",
],
"References" => "references.md", "References" => "references.md",
] ]
) )

View file

@ -33,3 +33,30 @@
archiveprefix = {arXiv}, archiveprefix = {arXiv},
primaryclass = {cs.LG} primaryclass = {cs.LG}
} }
@misc{lin2023common,
title = {Common Diffusion Noise Schedules and Sample Steps are Flawed},
author = {Shanchuan Lin and Bingchen Liu and Jiashi Li and Xiao Yang},
year = {2023},
eprint = {2305.08891},
archiveprefix = {arXiv},
primaryclass = {cs.CV}
}
@misc{xu2022geodiff,
title = {GeoDiff: a Geometric Diffusion Model for Molecular Conformation Generation},
author = {Minkai Xu and Lantao Yu and Yang Song and Chence Shi and Stefano Ermon and Jian Tang},
year = {2022},
eprint = {2203.02923},
archiveprefix = {arXiv},
primaryclass = {cs.LG}
}
@misc{salimans2022progressive,
title = {Progressive Distillation for Fast Sampling of Diffusion Models},
author = {Tim Salimans and Jonathan Ho},
year = {2022},
eprint = {2202.00512},
archiveprefix = {arXiv},
primaryclass = {cs.LG}
}

View file

@ -0,0 +1,180 @@
# This example compares the different beta schedules available in Diffusers.jl.
# Code related to the generation of the datasets and the plots is hidden.
using Diffusers.Schedulers: DDPM, forward # hide
using Diffusers.BetaSchedules # hide
using ProgressMeter # hide
using LaTeXStrings # hide
using Random # hide
using Plots # hide
using Flux # hide
using MLDatasets # hide
## Affinely rescale `x` so its global minimum maps to 0 and its global maximum to 1. # hide
## A constant input yields NaN entries because the range `hi - lo` is zero. # hide
function normalize_zero_to_one(x) # hide
    lo, hi = extrema(x) # hide
    return (x .- lo) ./ (hi - lo) # hide
end # hide
## Affinely rescale `x` so its global minimum maps to -1 and its global maximum to 1, # hide
## by stretching the [0, 1] normalization of `normalize_zero_to_one`. # hide
function normalize_neg_one_to_one(x) # hide
    unit_scaled = normalize_zero_to_one(x) # hide
    return 2 * unit_scaled .- 1 # hide
end # hide
## Number of diffusion steps shared by every scheduler in the comparison. # hide
num_timesteps = 100 # hide
## Beta schedule constructors under comparison; their names become the panel titles. # hide
beta_schedules = [ # hide
    linear_beta_schedule, # hide
    scaled_linear_beta_schedule, # hide
    cosine_beta_schedule, # hide
    sigmoid_beta_schedule, # hide
    exponential_beta_schedule, # hide
] # hide
## One DDPM scheduler per beta schedule, in the same order as `beta_schedules`. # hide
schedulers = map(beta_schedules) do make_betas # hide
    DDPM(collect(make_betas(num_timesteps))) # hide
end; # hide
# ## Swiss Roll
## Sample `n_samples` points on the spiral (t cos t, t sin t) with the angle t drawn # hide
## uniformly between `t_min` and `t_max`. # hide
## Returns a 2 x n_samples matrix: row 1 holds the x coordinates, row 2 the y coordinates. # hide
function make_spiral(n_samples::Integer=1000, t_min::Real=1.5π, t_max::Real=4.5π) # hide
    ## Sample in a floating-point type derived from both bounds. Using `typeof(t_min)` # hide
    ## directly breaks for non-float bounds: `rand(Int, n)` returns arbitrary integers # hide
    ## instead of values in [0, 1), and an irrational such as `π` cannot be sampled. # hide
    T = float(promote_type(typeof(t_min), typeof(t_max))) # hide
    t = rand(T, n_samples) * (t_max - t_min) .+ t_min # hide
    x = t .* cos.(t) # hide
    y = t .* sin.(t) # hide
    return permutedims([x y], (2, 1)) # hide
end # hide
## Sample the spiral, rescale it to [-1, 1], and draw a single noise array that is # hide
## reused for every timestep and every scheduler. # hide
n_points = 1000; # hide
dataset = normalize_neg_one_to_one(make_spiral(n_points, 1.5f0 * π, 4.5f0 * π)); # hide
noise = randn(Float32, size(dataset)) # hide
## Frame sequence: 20 frames of clean data (t = 0), one frame per timestep, then # hide
## 20 frames held on the last timestep. # hide
anim = @animate for t in vcat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20)) # hide
    panels = [] # hide
    for (scheduler, beta_schedule) in zip(schedulers, beta_schedules) # hide
        ## First layer: the clean data, drawn in every frame as a fixed reference. # hide
        scatter(dataset[1, :], dataset[2, :]; # hide
            alpha=0.5, # hide
            aspectratio=:equal, # hide
            legend=false, # hide
        ) # hide
        ## Second layer: the clean data again at t = 0, otherwise the data noised to step t. # hide
        points = t == 0 ? dataset : forward(scheduler, dataset, noise, [t]) # hide
        panel = scatter!(points[1, :], points[2, :]; # hide
            alpha=0.5, # hide
            aspectratio=:equal, # hide
        ) # hide
        title!(string(beta_schedule)) # hide
        xlims!(-3, 3) # hide
        ylims!(-3, 3) # hide
        push!(panels, panel) # hide
    end # hide
    plot(panels...; size=(1200, 800)) # hide
end # hide
gif(anim, anim.dir * ".gif", fps=20) # hide
# ## Double Square
## Sample `n_samples` points on the boundary of the square [-1, 1] x [-1, 1]. # hide
## Each point is drawn uniformly inside the square and then divided by its largest # hide
## absolute coordinate, which pushes it outwards onto the boundary. # hide
## Returns a 2 x n_samples matrix: row 1 holds the x coordinates, row 2 the y coordinates. # hide
function make_square(n_samples::Integer=1000) # hide
    x = 2 .* rand(n_samples) .- 1 # hide
    y = 2 .* rand(n_samples) .- 1 # hide
    pts = permutedims(hcat(x, y)) # hide
    return pts ./ maximum(abs, pts; dims=1) # hide
end # hide
## Two square outlines of half size, one shifted by -1.5 and one by +1.5 in both # hide
## coordinates, each built from half of the points. # hide
dataset = let half = Int(n_points / 2) # hide
    hcat(make_square(half) ./ 2 .- 1.5, make_square(half) ./ 2 .+ 1.5) # hide
end # hide
## A single noise array, reused for every timestep and every scheduler. # hide
noise = randn(Float32, size(dataset)) # hide
## Frame sequence: 20 frames of clean data (t = 0), one frame per timestep, then # hide
## 20 frames held on the last timestep. # hide
anim = @animate for t in vcat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20)) # hide
    panels = [] # hide
    for (scheduler, beta_schedule) in zip(schedulers, beta_schedules) # hide
        ## First layer: the clean data, drawn in every frame as a fixed reference. # hide
        scatter(dataset[1, :], dataset[2, :]; # hide
            alpha=0.5, # hide
            aspectratio=:equal, # hide
            legend=false, # hide
        ) # hide
        ## Second layer: the clean data again at t = 0, otherwise the data noised to step t. # hide
        points = t == 0 ? dataset : forward(scheduler, dataset, noise, [t]) # hide
        panel = scatter!(points[1, :], points[2, :]; # hide
            alpha=0.5, # hide
            aspectratio=:equal, # hide
        ) # hide
        title!(string(beta_schedule)) # hide
        xlims!(-3, 3) # hide
        ylims!(-3, 3) # hide
        push!(panels, panel) # hide
    end # hide
    plot(panels...; size=(1200, 800)) # hide
end # hide
gif(anim, anim.dir * ".gif", fps=20) # hide
# ## MNIST
## Take the 16th MNIST test sample, rotate it left by 90 degrees (presumably so the # hide
## heatmap shows the digit upright) and rescale the pixel values to [-1, 1]. # hide
dataset = normalize_neg_one_to_one(rotl90(MNIST(:test)[16].features)) # hide
## A single noise array, reused for every timestep and every scheduler. # hide
noise = randn(Float32, size(dataset)) # hide
## Frame sequence: 20 frames of the clean image (t = 0), one frame per timestep, then # hide
## 20 frames held on the last timestep. # hide
anim = @animate for t in vcat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20)) # hide
    panels = [] # hide
    for (scheduler, beta_schedule) in zip(schedulers, beta_schedules) # hide
        ## The clean image at t = 0, otherwise the image noised to step t. # hide
        image = t == 0 ? dataset : forward(scheduler, dataset, noise, [t]) # hide
        panel = heatmap( # hide
            image; # hide
            c=:grayC, # hide
            legend=:none, # hide
            aspect_ratio=:equal, # hide
            grid=false, # hide
            axis=false, # hide
        ) # hide
        title!(string(beta_schedule)) # hide
        push!(panels, panel) # hide
    end # hide
    plot(panels...; size=(1200, 800)) # hide
end # hide
gif(anim, anim.dir * ".gif", fps=20) # hide

View file

@ -1,6 +1,10 @@
""" """
Cosine beta schedule. Cosine beta schedule.
```math
\\overline{\\alpha}_t = \\cos^2 \\left( \\frac{t / T + \\epsilon}{1 + \\epsilon} \\frac{\\pi}{2} \\right)
```
## Input ## Input
* `T::Int`: number of timesteps * `T::Int`: number of timesteps
* `βₘₐₓ::Real=0.999f0`: maximum value of β * `βₘₐₓ::Real=0.999f0`: maximum value of β
@ -10,8 +14,8 @@ Cosine beta schedule.
* `β::Vector{Real}`: βₜ values at each timestep t * `β::Vector{Real}`: βₜ values at each timestep t
## References ## References
* [[2102.09672] Improved Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2102.09672) * [nichol2021improved; Improved Denoising Diffusion Probabilistic Models](@cite)
* [github:openai/improved-diffusion](https://github.com/openai/improved-diffusion/blob/783b6740edb79fdb7d063250db2c51cc9545dcd1/improved_diffusion/gaussian_diffusion.py#L36) * [github:openai/improved-diffusion/improved_diffusion/gaussian_diffusion.py](https://github.com/openai/improved-diffusion/blob/783b6740edb79fdb7d063250db2c51cc9545dcd1/improved_diffusion/gaussian_diffusion.py#L36)
""" """
function cosine_beta_schedule(T::Integer, βₘₐₓ::Real=0.999f0, ϵ::Real=1.0f-3) function cosine_beta_schedule(T::Integer, βₘₐₓ::Real=0.999f0, ϵ::Real=1.0f-3)
α̅(t) = cos((t / T + ϵ) / (1 + ϵ) * π / 2)^2 α̅(t) = cos((t / T + ϵ) / (1 + ϵ) * π / 2)^2

View file

@ -1,14 +1,16 @@
""" """
Exponential beta schedule. Exponential beta schedule.
```math
\\overline{\\alpha}_t = \\exp \\left( \\frac{-12 t}{T} \\right)
```
## Input ## Input
* `T::Int`: number of timesteps * `T::Int`: number of timesteps
* `βₘₐₓ::Real=0.999f0`: maximum value of β * `βₘₐₓ::Real=0.999f0`: maximum value of β
## Output ## Output
* `β::Vector{Real}`: βₜ values at each timestep t * `β::Vector{Real}`: βₜ values at each timestep t
## References
""" """
function exponential_beta_schedule(T::Integer, βₘₐₓ::Real=0.999f0) function exponential_beta_schedule(T::Integer, βₘₐₓ::Real=0.999f0)
α̅(t) = exp(-12 * t / T) α̅(t) = exp(-12 * t / T)

View file

@ -1,6 +1,10 @@
""" """
Linear beta schedule. Linear beta schedule.
```math
\\beta_t = \\beta_1 + \\frac{t - 1}{T - 1} (\\beta_{-1} - \\beta_1)
```
## Input ## Input
* `T::Integer`: number of timesteps * `T::Integer`: number of timesteps
* `β₁::Real=1.0f-4`: initial (t=1) value of β * `β₁::Real=1.0f-4`: initial (t=1) value of β
@ -10,7 +14,7 @@ Linear beta schedule.
* `β::Vector{Real}`: βₜ values at each timestep t * `β::Vector{Real}`: βₜ values at each timestep t
## References ## References
* [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) * [ho2020denoising; Denoising Diffusion Probabilistic Models](@cite)
""" """
function linear_beta_schedule(T::Integer, β₁::Real=1.0f-4, β₋₁::Real=2.0f-2) function linear_beta_schedule(T::Integer, β₁::Real=1.0f-4, β₋₁::Real=2.0f-2)
return range(start=β₁, stop=β₋₁, length=T) return range(start=β₁, stop=β₋₁, length=T)

View file

@ -1,6 +1,10 @@
""" """
Scaled linear beta schedule. Scaled linear beta schedule.
```math
\\beta_t = \\left( \\sqrt{\\beta_1} + \\frac{t - 1}{T - 1} \\left( \\sqrt{\\beta_{-1}} - \\sqrt{\\beta_1} \\right) \\right)^2
```
## Input ## Input
* `T::Int`: number of timesteps * `T::Int`: number of timesteps
* `β₁::Real=1.0f-4`: initial value of β * `β₁::Real=1.0f-4`: initial value of β
@ -10,7 +14,7 @@ Scaled linear beta schedule.
* `β::Vector{Real}`: βₜ values at each timestep t * `β::Vector{Real}`: βₜ values at each timestep t
## References ## References
* [[2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) * [ho2020denoising; Denoising Diffusion Probabilistic Models](@cite)
""" """
function scaled_linear_beta_schedule(T::Integer, β₁::Real=1.0f-4, β₋₁::Real=2.0f-2) function scaled_linear_beta_schedule(T::Integer, β₁::Real=1.0f-4, β₋₁::Real=2.0f-2)
return range(start=β₁, stop=β₋₁, length=T) .^ 2 return range(start=β₁, stop=β₋₁, length=T) .^ 2

View file

@ -3,6 +3,10 @@ import NNlib: sigmoid
""" """
Sigmoid beta schedule. Sigmoid beta schedule.
```math
\\beta_t = \\sigma \\left( 12 \\frac{t - 1}{T - 1} - 6 \\right) ( \\beta_{-1} - \\beta_1 ) + \\beta_1
```
## Input ## Input
* `T::Int`: number of timesteps * `T::Int`: number of timesteps
* `β₁::Real=1.0f-4`: initial value of β * `β₁::Real=1.0f-4`: initial value of β
@ -12,8 +16,8 @@ Sigmoid beta schedule.
* `β::Vector{Real}`: βₜ values at each timestep t * `β::Vector{Real}`: βₜ values at each timestep t
## References ## References
* [[2203.02923] GeoDiff: a Geometric Diffusion Model for Molecular Conformation Generation](https://arxiv.org/abs/2203.02923) * [xu2022geodiff; GeoDiff: a Geometric Diffusion Model for Molecular Conformation Generation](@cite)
* [github.com:MinkaiXu/GeoDiff](https://github.com/MinkaiXu/GeoDiff/blob/ea0ca48045a2f7abfccd7f0df449e45eb6eae638/models/epsnet/diffusion.py#L57) * [github.com:MinkaiXu/GeoDiff/models/epsnet/diffusion.py](https://github.com/MinkaiXu/GeoDiff/blob/ea0ca48045a2f7abfccd7f0df449e45eb6eae638/models/epsnet/diffusion.py#L57)
""" """
function sigmoid_beta_schedule(T::Integer, β₁::Real=1.0f-4, β₋₁::Real=2.0f-2) function sigmoid_beta_schedule(T::Integer, β₁::Real=1.0f-4, β₋₁::Real=2.0f-2)
x = range(start=-6, stop=6, length=T) x = range(start=-6, stop=6, length=T)

View file

@ -1,5 +1,5 @@
""" """
Rescale betas to have zero terminal Signal to Noise Ratio (SNR). Rescale β to have zero terminal Signal to Noise Ratio (SNR).
## Input ## Input
* `β::AbstractArray`: βₜ values at each timestep t * `β::AbstractArray`: βₜ values at each timestep t
@ -8,7 +8,7 @@ Rescale betas to have zero terminal Signal to Noise Ratio (SNR).
* `β::Vector{Real}`: rescaled βₜ values at each timestep t * `β::Vector{Real}`: rescaled βₜ values at each timestep t
## References ## References
* [[2305.08891] Rescaling Diffusion Models](https://arxiv.org/abs/2305.08891) (Alg. 1) * [lin2023common; Common Diffusion Noise Schedules and Sample Steps are Flawed (Alg. 1)](@cite)
""" """
function rescale_zero_terminal_snr(β::AbstractArray) function rescale_zero_terminal_snr(β::AbstractArray)
# convert β to ⎷α̅ # convert β to ⎷α̅

View file

@ -27,7 +27,7 @@ export
include("Schedulers/Schedulers.jl") include("Schedulers/Schedulers.jl")
import .Schedulers: import .Schedulers:
# Scheduler # Schedulers
DDPM, DDPM,
DDIM, DDIM,
@ -51,7 +51,7 @@ import .Schedulers:
VELOCITY VELOCITY
export export
# Scheduler # Schedulers
DDPM, DDPM,
DDIM, DDIM,

View file

@ -1,11 +1,11 @@
@enum VarianceType FIXED_SMALL FIXED_SMALL_LOG FIXED_LARGE FIXED_LARGE_LOG LEARNED
@enum PredictionType EPSILON SAMPLE VELOCITY
""" """
Abstract type for schedulers. Abstract type for schedulers.
""" """
abstract type Scheduler end abstract type Scheduler end
@enum VarianceType FIXED_SMALL FIXED_SMALL_LOG FIXED_LARGE FIXED_LARGE_LOG LEARNED
@enum PredictionType EPSILON SAMPLE VELOCITY
""" """
Add noise to clean data using the forward diffusion process. Add noise to clean data using the forward diffusion process.
@ -59,7 +59,7 @@ Compute the velocity of the diffusion process.
* `vₜ::AbstractArray`: velocity at the given timesteps * `vₜ::AbstractArray`: velocity at the given timesteps
## References ## References
* [[2202.00512] Progressive Distillation for Fast Sampling of Diffusion Models](https://arxiv.org/abs/2202.00512) (Ann. D) * [salimans2022progressive; App. D](@cite)
""" """
function get_velocity( function get_velocity(
scheduler::Scheduler, scheduler::Scheduler,

View file

@ -1,5 +1,3 @@
include("Abstract.jl")
using ShiftedArrays using ShiftedArrays
function _extract( function _extract(

View file

@ -1,5 +1,3 @@
include("Abstract.jl")
using ShiftedArrays using ShiftedArrays
function _extract( function _extract(

View file

@ -1,5 +1,6 @@
module Schedulers module Schedulers
include("Abstract.jl")
include("DDPM.jl") include("DDPM.jl")
include("DDIM.jl") include("DDIM.jl")