mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-09-16 17:45:29 +00:00
✨ (docs) use Literate.jl to showcase examples
This commit is contained in:
parent
90036c12b7
commit
1eec2cbf3e
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1,6 +1,7 @@
|
||||||
.direnv
|
.direnv
|
||||||
*.gif
|
*.gif
|
||||||
*.jld2
|
*.jld2
|
||||||
|
generated/
|
||||||
|
|
||||||
# https://github.com/github/gitignore/blob/main/Julia.gitignore
|
# https://github.com/github/gitignore/blob/main/Julia.gitignore
|
||||||
# Files generated by invoking Julia with --code-coverage
|
# Files generated by invoking Julia with --code-coverage
|
||||||
|
|
|
@ -1,6 +1,12 @@
|
||||||
[deps]
|
[deps]
|
||||||
|
DenoisingDiffusion = "32e9e46b-ad0f-4c80-b32d-4f6f824844ef"
|
||||||
Diffusers = "90edb7a8-79d7-49b2-b6b1-9322c3fdead8"
|
Diffusers = "90edb7a8-79d7-49b2-b6b1-9322c3fdead8"
|
||||||
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
|
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
|
||||||
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
|
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
|
||||||
|
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
|
||||||
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
|
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
|
||||||
|
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
|
||||||
|
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
|
||||||
PlotlyJS = "f0f68f2c-4968-5e81-91da-67840de0976a"
|
PlotlyJS = "f0f68f2c-4968-5e81-91da-67840de0976a"
|
||||||
|
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
|
||||||
|
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
|
||||||
|
|
12
docs/make.jl
12
docs/make.jl
|
@ -1,6 +1,9 @@
|
||||||
using Diffusers
|
using Diffusers
|
||||||
using Documenter
|
using Documenter
|
||||||
using DocumenterCitations
|
using DocumenterCitations
|
||||||
|
using Literate
|
||||||
|
|
||||||
|
Literate.markdown(joinpath(@__DIR__, "..", "examples", "beta_schedulers_comparison.jl"), joinpath(@__DIR__, "src", "generated"))
|
||||||
|
|
||||||
DocMeta.setdocmeta!(Diffusers, :DocTestSetup, :(using Diffusers); recursive=true)
|
DocMeta.setdocmeta!(Diffusers, :DocTestSetup, :(using Diffusers); recursive=true)
|
||||||
bib = CitationBibliography(joinpath(@__DIR__, "src", "refs.bib"))
|
bib = CitationBibliography(joinpath(@__DIR__, "src", "refs.bib"))
|
||||||
|
@ -18,8 +21,13 @@ makedocs(bib;
|
||||||
linkcheck=true,
|
linkcheck=true,
|
||||||
pages=[
|
pages=[
|
||||||
"Home" => "index.md",
|
"Home" => "index.md",
|
||||||
"Schedulers" => "schedulers.md",
|
"API" => [
|
||||||
"Beta Schedules" => "beta_schedules.md",
|
"Schedulers" => "schedulers.md",
|
||||||
|
"Beta Schedules" => "beta_schedules.md",
|
||||||
|
],
|
||||||
|
"Examples" => [
|
||||||
|
"Beta Schedules Comparison" => "generated/beta_schedulers_comparison.md",
|
||||||
|
],
|
||||||
"References" => "references.md",
|
"References" => "references.md",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -33,3 +33,30 @@
|
||||||
archiveprefix = {arXiv},
|
archiveprefix = {arXiv},
|
||||||
primaryclass = {cs.LG}
|
primaryclass = {cs.LG}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@misc{lin2023common,
|
||||||
|
title = {Common Diffusion Noise Schedules and Sample Steps are Flawed},
|
||||||
|
author = {Shanchuan Lin and Bingchen Liu and Jiashi Li and Xiao Yang},
|
||||||
|
year = {2023},
|
||||||
|
eprint = {2305.08891},
|
||||||
|
archiveprefix = {arXiv},
|
||||||
|
primaryclass = {cs.CV}
|
||||||
|
}
|
||||||
|
|
||||||
|
@misc{xu2022geodiff,
|
||||||
|
title = {GeoDiff: a Geometric Diffusion Model for Molecular Conformation Generation},
|
||||||
|
author = {Minkai Xu and Lantao Yu and Yang Song and Chence Shi and Stefano Ermon and Jian Tang},
|
||||||
|
year = {2022},
|
||||||
|
eprint = {2203.02923},
|
||||||
|
archiveprefix = {arXiv},
|
||||||
|
primaryclass = {cs.LG}
|
||||||
|
}
|
||||||
|
|
||||||
|
@misc{salimans2022progressive,
|
||||||
|
title = {Progressive Distillation for Fast Sampling of Diffusion Models},
|
||||||
|
author = {Tim Salimans and Jonathan Ho},
|
||||||
|
year = {2022},
|
||||||
|
eprint = {2202.00512},
|
||||||
|
archiveprefix = {arXiv},
|
||||||
|
primaryclass = {cs.LG}
|
||||||
|
}
|
180
examples/beta_schedulers_comparison.jl
Normal file
180
examples/beta_schedulers_comparison.jl
Normal file
|
@ -0,0 +1,180 @@
|
||||||
|
# This example compares the different beta schedules available in Diffusers.jl.
|
||||||
|
# Code related to the generation of the datasets and the plots is hidden.
|
||||||
|
|
||||||
|
using Diffusers.Schedulers: DDPM, forward # hide
|
||||||
|
using Diffusers.BetaSchedules # hide
|
||||||
|
using ProgressMeter # hide
|
||||||
|
using LaTeXStrings # hide
|
||||||
|
using Random # hide
|
||||||
|
using Plots # hide
|
||||||
|
using Flux # hide
|
||||||
|
using MLDatasets # hide
|
||||||
|
|
||||||
|
function normalize_zero_to_one(x) # hide
|
||||||
|
x_min, x_max = extrema(x) # hide
|
||||||
|
x_norm = (x .- x_min) ./ (x_max - x_min) # hide
|
||||||
|
x_norm # hide
|
||||||
|
end # hide
|
||||||
|
|
||||||
|
function normalize_neg_one_to_one(x) # hide
|
||||||
|
2 * normalize_zero_to_one(x) .- 1 # hide
|
||||||
|
end # hide
|
||||||
|
|
||||||
|
num_timesteps = 100 # hide
|
||||||
|
beta_schedules = [ # hide
|
||||||
|
linear_beta_schedule, # hide
|
||||||
|
scaled_linear_beta_schedule, # hide
|
||||||
|
cosine_beta_schedule, # hide
|
||||||
|
sigmoid_beta_schedule, # hide
|
||||||
|
exponential_beta_schedule, # hide
|
||||||
|
] # hide
|
||||||
|
schedulers = [ # hide
|
||||||
|
DDPM(collect(schedule(num_timesteps))) for schedule in beta_schedules # hide
|
||||||
|
]; # hide
|
||||||
|
|
||||||
|
# ## Swiss Roll
|
||||||
|
|
||||||
|
function make_spiral(n_samples::Integer=1000, t_min::Real=1.5π, t_max::Real=4.5π) # hide
|
||||||
|
t = rand(typeof(t_min), n_samples) * (t_max - t_min) .+ t_min # hide
|
||||||
|
|
||||||
|
x = t .* cos.(t) # hide
|
||||||
|
y = t .* sin.(t) # hide
|
||||||
|
|
||||||
|
permutedims([x y], (2, 1)) # hide
|
||||||
|
end # hide
|
||||||
|
|
||||||
|
n_points = 1000; # hide
|
||||||
|
dataset = make_spiral(n_points, 1.5f0 * π, 4.5f0 * π); # hide
|
||||||
|
dataset = normalize_neg_one_to_one(dataset); # hide
|
||||||
|
|
||||||
|
noise = randn(Float32, size(dataset)) # hide
|
||||||
|
anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20), dims=1) # hide
|
||||||
|
plots = [] # hide
|
||||||
|
for (i, (scheduler, beta_schedule)) in enumerate(zip(schedulers, beta_schedules)) # hide
|
||||||
|
if t == 0 # hide
|
||||||
|
scatter(dataset[1, :], dataset[2, :], # hide
|
||||||
|
alpha=0.5, # hide
|
||||||
|
aspectratio=:equal, # hide
|
||||||
|
legend=false, # hide
|
||||||
|
) # hide
|
||||||
|
plot = scatter!(dataset[1, :], dataset[2, :], # hide
|
||||||
|
alpha=0.5, # hide
|
||||||
|
aspectratio=:equal, # hide
|
||||||
|
) # hide
|
||||||
|
title!(string(beta_schedule)) # hide
|
||||||
|
xlims!(-3, 3) # hide
|
||||||
|
ylims!(-3, 3) # hide
|
||||||
|
else # hide
|
||||||
|
scatter(dataset[1, :], dataset[2, :], # hide
|
||||||
|
alpha=0.5, # hide
|
||||||
|
aspectratio=:equal, # hide
|
||||||
|
legend=false, # hide
|
||||||
|
) # hide
|
||||||
|
noisy_data = forward(scheduler, dataset, noise, [t]) # hide
|
||||||
|
plot = scatter!(noisy_data[1, :], noisy_data[2, :], # hide
|
||||||
|
alpha=0.5, # hide
|
||||||
|
aspectratio=:equal, # hide
|
||||||
|
) # hide
|
||||||
|
title!(string(beta_schedule)) # hide
|
||||||
|
xlims!(-3, 3) # hide
|
||||||
|
ylims!(-3, 3) # hide
|
||||||
|
end # hide
|
||||||
|
push!(plots, plot) # hide
|
||||||
|
end # hide
|
||||||
|
plot(plots...; size=(1200, 800)) # hide
|
||||||
|
end # hide
|
||||||
|
|
||||||
|
gif(anim, anim.dir * ".gif", fps=20) # hide
|
||||||
|
|
||||||
|
# ## Double Square
|
||||||
|
|
||||||
|
function make_square(n_samples::Integer=1000) # hide
|
||||||
|
x = rand(n_samples) .* 2 .- 1 # hide
|
||||||
|
y = rand(n_samples) .* 2 .- 1 # hide
|
||||||
|
p = permutedims([x y], (2, 1)) # hide
|
||||||
|
p ./ maximum(abs.(p), dims=1) # hide
|
||||||
|
end # hide
|
||||||
|
|
||||||
|
dataset = hcat( # hide
|
||||||
|
make_square(Int(n_points / 2)) ./ 2 .- 1.5, # hide
|
||||||
|
make_square(Int(n_points / 2)) ./ 2 .+ 1.5 # hide
|
||||||
|
) # hide
|
||||||
|
|
||||||
|
noise = randn(Float32, size(dataset)) # hide
|
||||||
|
anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20), dims=1) # hide
|
||||||
|
plots = [] # hide
|
||||||
|
for (i, (scheduler, beta_schedule)) in enumerate(zip(schedulers, beta_schedules)) # hide
|
||||||
|
if t == 0 # hide
|
||||||
|
scatter(dataset[1, :], dataset[2, :], # hide
|
||||||
|
alpha=0.5, # hide
|
||||||
|
aspectratio=:equal, # hide
|
||||||
|
legend=false, # hide
|
||||||
|
) # hide
|
||||||
|
plot = scatter!(dataset[1, :], dataset[2, :], # hide
|
||||||
|
alpha=0.5, # hide
|
||||||
|
aspectratio=:equal, # hide
|
||||||
|
) # hide
|
||||||
|
title!(string(beta_schedule)) # hide
|
||||||
|
xlims!(-3, 3) # hide
|
||||||
|
ylims!(-3, 3) # hide
|
||||||
|
else # hide
|
||||||
|
scatter(dataset[1, :], dataset[2, :], # hide
|
||||||
|
alpha=0.5, # hide
|
||||||
|
aspectratio=:equal, # hide
|
||||||
|
legend=false, # hide
|
||||||
|
) # hide
|
||||||
|
noisy_data = forward(scheduler, dataset, noise, [t]) # hide
|
||||||
|
plot = scatter!(noisy_data[1, :], noisy_data[2, :], # hide
|
||||||
|
alpha=0.5, # hide
|
||||||
|
aspectratio=:equal, # hide
|
||||||
|
) # hide
|
||||||
|
title!(string(beta_schedule)) # hide
|
||||||
|
xlims!(-3, 3) # hide
|
||||||
|
ylims!(-3, 3) # hide
|
||||||
|
end # hide
|
||||||
|
push!(plots, plot) # hide
|
||||||
|
end # hide
|
||||||
|
plot(plots...; size=(1200, 800)) # hide
|
||||||
|
end # hide
|
||||||
|
|
||||||
|
gif(anim, anim.dir * ".gif", fps=20) # hide
|
||||||
|
|
||||||
|
# ## MNIST
|
||||||
|
|
||||||
|
dataset = MNIST(:test)[16].features # hide
|
||||||
|
dataset = rotl90(dataset) # hide
|
||||||
|
dataset = normalize_neg_one_to_one(dataset) # hide
|
||||||
|
|
||||||
|
noise = randn(Float32, size(dataset)) # hide
|
||||||
|
anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20), dims=1) # hide
|
||||||
|
plots = [] # hide
|
||||||
|
for (i, (scheduler, beta_schedule)) in enumerate(zip(schedulers, beta_schedules)) # hide
|
||||||
|
if t == 0 # hide
|
||||||
|
plot = heatmap( # hide
|
||||||
|
dataset, # hide
|
||||||
|
c=:grayC, # hide
|
||||||
|
legend=:none, # hide
|
||||||
|
aspect_ratio=:equal, # hide
|
||||||
|
grid=false, # hide
|
||||||
|
axis=false # hide
|
||||||
|
) # hide
|
||||||
|
title!(string(beta_schedule)) # hide
|
||||||
|
else # hide
|
||||||
|
noisy_data = forward(scheduler, dataset, noise, [t]) # hide
|
||||||
|
plot = heatmap( # hide
|
||||||
|
noisy_data, # hide
|
||||||
|
c=:grayC, # hide
|
||||||
|
legend=:none, # hide
|
||||||
|
aspect_ratio=:equal, # hide
|
||||||
|
grid=false, # hide
|
||||||
|
axis=false # hide
|
||||||
|
) # hide
|
||||||
|
title!(string(beta_schedule)) # hide
|
||||||
|
end # hide
|
||||||
|
|
||||||
|
push!(plots, plot) # hide
|
||||||
|
end # hide
|
||||||
|
plot(plots...; size=(1200, 800)) # hide
|
||||||
|
end # hide
|
||||||
|
|
||||||
|
gif(anim, anim.dir * ".gif", fps=20) # hide
|
|
@ -27,7 +27,7 @@ export
|
||||||
include("Schedulers/Schedulers.jl")
|
include("Schedulers/Schedulers.jl")
|
||||||
|
|
||||||
import .Schedulers:
|
import .Schedulers:
|
||||||
# Scheduler
|
# Schedulers
|
||||||
DDPM,
|
DDPM,
|
||||||
DDIM,
|
DDIM,
|
||||||
|
|
||||||
|
@ -51,7 +51,7 @@ import .Schedulers:
|
||||||
VELOCITY
|
VELOCITY
|
||||||
|
|
||||||
export
|
export
|
||||||
# Scheduler
|
# Schedulers
|
||||||
DDPM,
|
DDPM,
|
||||||
DDIM,
|
DDIM,
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue