(docs) use Literate.jl to showcase examples

This commit is contained in:
Laureηt 2023-10-06 16:58:44 +00:00
parent 90036c12b7
commit 1eec2cbf3e
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
6 changed files with 226 additions and 4 deletions

1
.gitignore vendored
View file

@ -1,6 +1,7 @@
.direnv
*.gif
*.jld2
generated/
# https://github.com/github/gitignore/blob/main/Julia.gitignore
# Files generated by invoking Julia with --code-coverage

View file

@ -1,6 +1,12 @@
[deps]
DenoisingDiffusion = "32e9e46b-ad0f-4c80-b32d-4f6f824844ef"
Diffusers = "90edb7a8-79d7-49b2-b6b1-9322c3fdead8"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
PlotlyJS = "f0f68f2c-4968-5e81-91da-67840de0976a"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"

View file

@ -1,6 +1,9 @@
using Diffusers
using Documenter
using DocumenterCitations
using Literate
Literate.markdown(joinpath(@__DIR__, "..", "examples", "beta_schedulers_comparison.jl"), joinpath(@__DIR__, "src", "generated"))
DocMeta.setdocmeta!(Diffusers, :DocTestSetup, :(using Diffusers); recursive=true)
bib = CitationBibliography(joinpath(@__DIR__, "src", "refs.bib"))
@ -18,8 +21,13 @@ makedocs(bib;
linkcheck=true,
pages=[
"Home" => "index.md",
"Schedulers" => "schedulers.md",
"Beta Schedules" => "beta_schedules.md",
"API" => [
"Schedulers" => "schedulers.md",
"Beta Schedules" => "beta_schedules.md",
],
"Examples" => [
"Beta Schedules Comparison" => "generated/beta_schedulers_comparison.md",
],
"References" => "references.md",
]
)

View file

@ -33,3 +33,30 @@
archiveprefix = {arXiv},
primaryclass = {cs.LG}
}
@misc{lin2023common,
title = {Common Diffusion Noise Schedules and Sample Steps are Flawed},
author = {Shanchuan Lin and Bingchen Liu and Jiashi Li and Xiao Yang},
year = {2023},
eprint = {2305.08891},
archiveprefix = {arXiv},
primaryclass = {cs.CV}
}
@misc{xu2022geodiff,
title = {GeoDiff: a Geometric Diffusion Model for Molecular Conformation Generation},
author = {Minkai Xu and Lantao Yu and Yang Song and Chence Shi and Stefano Ermon and Jian Tang},
year = {2022},
eprint = {2203.02923},
archiveprefix = {arXiv},
primaryclass = {cs.LG}
}
@misc{salimans2022progressive,
title = {Progressive Distillation for Fast Sampling of Diffusion Models},
author = {Tim Salimans and Jonathan Ho},
year = {2022},
eprint = {2202.00512},
archiveprefix = {arXiv},
primaryclass = {cs.LG}
}

View file

@ -0,0 +1,180 @@
# This example compares the different beta schedules available in Diffusers.jl.
# Code related to the generation of the datasets and the plots is hidden.
using Diffusers.Schedulers: DDPM, forward # hide
using Diffusers.BetaSchedules # hide
using ProgressMeter # hide
using LaTeXStrings # hide
using Random # hide
using Plots # hide
using Flux # hide
using MLDatasets # hide
function normalize_zero_to_one(x) # hide
x_min, x_max = extrema(x) # hide
x_norm = (x .- x_min) ./ (x_max - x_min) # hide
x_norm # hide
end # hide
function normalize_neg_one_to_one(x) # hide
2 * normalize_zero_to_one(x) .- 1 # hide
end # hide
num_timesteps = 100 # hide
beta_schedules = [ # hide
linear_beta_schedule, # hide
scaled_linear_beta_schedule, # hide
cosine_beta_schedule, # hide
sigmoid_beta_schedule, # hide
exponential_beta_schedule, # hide
] # hide
schedulers = [ # hide
DDPM(collect(schedule(num_timesteps))) for schedule in beta_schedules # hide
]; # hide
# ## Swiss Roll
function make_spiral(n_samples::Integer=1000, t_min::Real=1.5π, t_max::Real=4.5π) # hide
t = rand(typeof(t_min), n_samples) * (t_max - t_min) .+ t_min # hide
x = t .* cos.(t) # hide
y = t .* sin.(t) # hide
permutedims([x y], (2, 1)) # hide
end # hide
n_points = 1000; # hide
dataset = make_spiral(n_points, 1.5f0 * π, 4.5f0 * π); # hide
dataset = normalize_neg_one_to_one(dataset); # hide
noise = randn(Float32, size(dataset)) # hide
anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20), dims=1) # hide
plots = [] # hide
for (i, (scheduler, beta_schedule)) in enumerate(zip(schedulers, beta_schedules)) # hide
if t == 0 # hide
scatter(dataset[1, :], dataset[2, :], # hide
alpha=0.5, # hide
aspectratio=:equal, # hide
legend=false, # hide
) # hide
plot = scatter!(dataset[1, :], dataset[2, :], # hide
alpha=0.5, # hide
aspectratio=:equal, # hide
) # hide
title!(string(beta_schedule)) # hide
xlims!(-3, 3) # hide
ylims!(-3, 3) # hide
else # hide
scatter(dataset[1, :], dataset[2, :], # hide
alpha=0.5, # hide
aspectratio=:equal, # hide
legend=false, # hide
) # hide
noisy_data = forward(scheduler, dataset, noise, [t]) # hide
plot = scatter!(noisy_data[1, :], noisy_data[2, :], # hide
alpha=0.5, # hide
aspectratio=:equal, # hide
) # hide
title!(string(beta_schedule)) # hide
xlims!(-3, 3) # hide
ylims!(-3, 3) # hide
end # hide
push!(plots, plot) # hide
end # hide
plot(plots...; size=(1200, 800)) # hide
end # hide
gif(anim, anim.dir * ".gif", fps=20) # hide
# ## Double Square
function make_square(n_samples::Integer=1000) # hide
x = rand(n_samples) .* 2 .- 1 # hide
y = rand(n_samples) .* 2 .- 1 # hide
p = permutedims([x y], (2, 1)) # hide
p ./ maximum(abs.(p), dims=1) # hide
end # hide
dataset = hcat( # hide
make_square(Int(n_points / 2)) ./ 2 .- 1.5, # hide
make_square(Int(n_points / 2)) ./ 2 .+ 1.5 # hide
) # hide
noise = randn(Float32, size(dataset)) # hide
anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20), dims=1) # hide
plots = [] # hide
for (i, (scheduler, beta_schedule)) in enumerate(zip(schedulers, beta_schedules)) # hide
if t == 0 # hide
scatter(dataset[1, :], dataset[2, :], # hide
alpha=0.5, # hide
aspectratio=:equal, # hide
legend=false, # hide
) # hide
plot = scatter!(dataset[1, :], dataset[2, :], # hide
alpha=0.5, # hide
aspectratio=:equal, # hide
) # hide
title!(string(beta_schedule)) # hide
xlims!(-3, 3) # hide
ylims!(-3, 3) # hide
else # hide
scatter(dataset[1, :], dataset[2, :], # hide
alpha=0.5, # hide
aspectratio=:equal, # hide
legend=false, # hide
) # hide
noisy_data = forward(scheduler, dataset, noise, [t]) # hide
plot = scatter!(noisy_data[1, :], noisy_data[2, :], # hide
alpha=0.5, # hide
aspectratio=:equal, # hide
) # hide
title!(string(beta_schedule)) # hide
xlims!(-3, 3) # hide
ylims!(-3, 3) # hide
end # hide
push!(plots, plot) # hide
end # hide
plot(plots...; size=(1200, 800)) # hide
end # hide
gif(anim, anim.dir * ".gif", fps=20) # hide
# ## MNIST
dataset = MNIST(:test)[16].features # hide
dataset = rotl90(dataset) # hide
dataset = normalize_neg_one_to_one(dataset) # hide
noise = randn(Float32, size(dataset)) # hide
anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20), dims=1) # hide
plots = [] # hide
for (i, (scheduler, beta_schedule)) in enumerate(zip(schedulers, beta_schedules)) # hide
if t == 0 # hide
plot = heatmap( # hide
dataset, # hide
c=:grayC, # hide
legend=:none, # hide
aspect_ratio=:equal, # hide
grid=false, # hide
axis=false # hide
) # hide
title!(string(beta_schedule)) # hide
else # hide
noisy_data = forward(scheduler, dataset, noise, [t]) # hide
plot = heatmap( # hide
noisy_data, # hide
c=:grayC, # hide
legend=:none, # hide
aspect_ratio=:equal, # hide
grid=false, # hide
axis=false # hide
) # hide
title!(string(beta_schedule)) # hide
end # hide
push!(plots, plot) # hide
end # hide
plot(plots...; size=(1200, 800)) # hide
end # hide
gif(anim, anim.dir * ".gif", fps=20) # hide

View file

@ -27,7 +27,7 @@ export
include("Schedulers/Schedulers.jl")
import .Schedulers:
# Scheduler
# Schedulers
DDPM,
DDIM,
@ -51,7 +51,7 @@ import .Schedulers:
VELOCITY
export
# Scheduler
# Schedulers
DDPM,
DDIM,