diff --git a/.gitignore b/.gitignore index 735d252..56a96d1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ .direnv *.gif *.jld2 +generated/ # https://github.com/github/gitignore/blob/main/Julia.gitignore # Files generated by invoking Julia with --code-coverage diff --git a/docs/Project.toml b/docs/Project.toml index db611c8..0ee5a5e 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,6 +1,12 @@ [deps] +DenoisingDiffusion = "32e9e46b-ad0f-4c80-b32d-4f6f824844ef" Diffusers = "90edb7a8-79d7-49b2-b6b1-9322c3fdead8" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" +Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" +MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" PlotlyJS = "f0f68f2c-4968-5e81-91da-67840de0976a" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" diff --git a/docs/make.jl b/docs/make.jl index bed2aba..88292b9 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,6 +1,9 @@ using Diffusers using Documenter using DocumenterCitations +using Literate + +Literate.markdown(joinpath(@__DIR__, "..", "examples", "beta_schedulers_comparison.jl"), joinpath(@__DIR__, "src", "generated")) DocMeta.setdocmeta!(Diffusers, :DocTestSetup, :(using Diffusers); recursive=true) bib = CitationBibliography(joinpath(@__DIR__, "src", "refs.bib")) @@ -18,8 +21,13 @@ makedocs(bib; linkcheck=true, pages=[ "Home" => "index.md", - "Schedulers" => "schedulers.md", - "Beta Schedules" => "beta_schedules.md", + "API" => [ + "Schedulers" => "schedulers.md", + "Beta Schedules" => "beta_schedules.md", + ], + "Examples" => [ + "Beta Schedules Comparison" => "generated/beta_schedulers_comparison.md", + ], "References" => "references.md", ] ) diff --git a/docs/src/refs.bib b/docs/src/refs.bib index a25be3b..125483d 100644 --- a/docs/src/refs.bib +++ b/docs/src/refs.bib @@ -33,3 +33,30 @@ archiveprefix = {arXiv}, primaryclass = {cs.LG} } + +@misc{lin2023common, + title = {Common Diffusion Noise Schedules and Sample Steps are Flawed}, + author = {Shanchuan Lin and Bingchen Liu and Jiashi Li and Xiao Yang}, + year = {2023}, + eprint = {2305.08891}, + archiveprefix = {arXiv}, + primaryclass = {cs.CV} +} + +@misc{xu2022geodiff, + title = {GeoDiff: a Geometric Diffusion Model for Molecular Conformation Generation}, + author = {Minkai Xu and Lantao Yu and Yang Song and Chence Shi and Stefano Ermon and Jian Tang}, + year = {2022}, + eprint = {2203.02923}, + archiveprefix = {arXiv}, + primaryclass = {cs.LG} +} + +@misc{salimans2022progressive, + title = {Progressive Distillation for Fast Sampling of Diffusion Models}, + author = {Tim Salimans and Jonathan Ho}, + year = {2022}, + eprint = {2202.00512}, + archiveprefix = {arXiv}, + primaryclass = {cs.LG} +} \ No newline at end of file diff --git a/examples/beta_schedulers_comparison.jl b/examples/beta_schedulers_comparison.jl new file mode 100644 index 0000000..8b386e1 --- /dev/null +++ b/examples/beta_schedulers_comparison.jl @@ -0,0 +1,180 @@ +# This example compares the different beta schedules available in Diffusers.jl. +# Code related to the generation of the datasets and the plots is hidden. + +using Diffusers.Schedulers: DDPM, forward # hide +using Diffusers.BetaSchedules # hide +using ProgressMeter # hide +using LaTeXStrings # hide +using Random # hide +using Plots # hide +using Flux # hide +using MLDatasets # hide + +function normalize_zero_to_one(x) # hide + x_min, x_max = extrema(x) # hide + x_norm = (x .- x_min) ./ (x_max - x_min) # hide + x_norm # hide +end # hide + +function normalize_neg_one_to_one(x) # hide + 2 * normalize_zero_to_one(x) .- 1 # hide +end # hide + +num_timesteps = 100 # hide +beta_schedules = [ # hide + linear_beta_schedule, # hide + scaled_linear_beta_schedule, # hide + cosine_beta_schedule, # hide + sigmoid_beta_schedule, # hide + exponential_beta_schedule, # hide +] # hide +schedulers = [ # hide + DDPM(collect(schedule(num_timesteps))) for schedule in beta_schedules # hide +]; # hide + +# ## Swiss Roll + +function make_spiral(n_samples::Integer=1000, t_min::Real=1.5π, t_max::Real=4.5π) # hide + t = rand(typeof(t_min), n_samples) * (t_max - t_min) .+ t_min # hide + + x = t .* cos.(t) # hide + y = t .* sin.(t) # hide + + permutedims([x y], (2, 1)) # hide +end # hide + +n_points = 1000; # hide +dataset = make_spiral(n_points, 1.5f0 * π, 4.5f0 * π); # hide +dataset = normalize_neg_one_to_one(dataset); # hide + +noise = randn(Float32, size(dataset)) # hide +anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20), dims=1) # hide + plots = [] # hide + for (i, (scheduler, beta_schedule)) in enumerate(zip(schedulers, beta_schedules)) # hide + if t == 0 # hide + scatter(dataset[1, :], dataset[2, :], # hide + alpha=0.5, # hide + aspectratio=:equal, # hide + legend=false, # hide + ) # hide + plot = scatter!(dataset[1, :], dataset[2, :], # hide + alpha=0.5, # hide + aspectratio=:equal, # hide + ) # hide + title!(string(beta_schedule)) # hide + xlims!(-3, 3) # hide + ylims!(-3, 3) # hide + else # hide + scatter(dataset[1, :], dataset[2, :], # hide + alpha=0.5, # hide + aspectratio=:equal, # hide + legend=false, # hide + ) # hide + noisy_data = forward(scheduler, dataset, noise, [t]) # hide + plot = scatter!(noisy_data[1, :], noisy_data[2, :], # hide + alpha=0.5, # hide + aspectratio=:equal, # hide + ) # hide + title!(string(beta_schedule)) # hide + xlims!(-3, 3) # hide + ylims!(-3, 3) # hide + end # hide + push!(plots, plot) # hide + end # hide + plot(plots...; size=(1200, 800)) # hide +end # hide + +gif(anim, anim.dir * ".gif", fps=20) # hide + +# ## Double Square + +function make_square(n_samples::Integer=1000) # hide + x = rand(n_samples) .* 2 .- 1 # hide + y = rand(n_samples) .* 2 .- 1 # hide + p = permutedims([x y], (2, 1)) # hide + p ./ maximum(abs.(p), dims=1) # hide +end # hide + +dataset = hcat( # hide + make_square(Int(n_points / 2)) ./ 2 .- 1.5, # hide + make_square(Int(n_points / 2)) ./ 2 .+ 1.5 # hide +) # hide + +noise = randn(Float32, size(dataset)) # hide +anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20), dims=1) # hide + plots = [] # hide + for (i, (scheduler, beta_schedule)) in enumerate(zip(schedulers, beta_schedules)) # hide + if t == 0 # hide + scatter(dataset[1, :], dataset[2, :], # hide + alpha=0.5, # hide + aspectratio=:equal, # hide + legend=false, # hide + ) # hide + plot = scatter!(dataset[1, :], dataset[2, :], # hide + alpha=0.5, # hide + aspectratio=:equal, # hide + ) # hide + title!(string(beta_schedule)) # hide + xlims!(-3, 3) # hide + ylims!(-3, 3) # hide + else # hide + scatter(dataset[1, :], dataset[2, :], # hide + alpha=0.5, # hide + aspectratio=:equal, # hide + legend=false, # hide + ) # hide + noisy_data = forward(scheduler, dataset, noise, [t]) # hide + plot = scatter!(noisy_data[1, :], noisy_data[2, :], # hide + alpha=0.5, # hide + aspectratio=:equal, # hide + ) # hide + title!(string(beta_schedule)) # hide + xlims!(-3, 3) # hide + ylims!(-3, 3) # hide + end # hide + push!(plots, plot) # hide + end # hide + plot(plots...; size=(1200, 800)) # hide +end # hide + +gif(anim, anim.dir * ".gif", fps=20) # hide + +# ## MNIST + +dataset = MNIST(:test)[16].features # hide +dataset = rotl90(dataset) # hide +dataset = normalize_neg_one_to_one(dataset) # hide + +noise = randn(Float32, size(dataset)) # hide +anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20), dims=1) # hide + plots = [] # hide + for (i, (scheduler, beta_schedule)) in enumerate(zip(schedulers, beta_schedules)) # hide + if t == 0 # hide + plot = heatmap( # hide + dataset, # hide + c=:grayC, # hide + legend=:none, # hide + aspect_ratio=:equal, # hide + grid=false, # hide + axis=false # hide + ) # hide + title!(string(beta_schedule)) # hide + else # hide + noisy_data = forward(scheduler, dataset, noise, [t]) # hide + plot = heatmap( # hide + noisy_data, # hide + c=:grayC, # hide + legend=:none, # hide + aspect_ratio=:equal, # hide + grid=false, # hide + axis=false # hide + ) # hide + title!(string(beta_schedule)) # hide + end # hide + + push!(plots, plot) # hide + end # hide + plot(plots...; size=(1200, 800)) # hide +end # hide + +gif(anim, anim.dir * ".gif", fps=20) # hide \ No newline at end of file diff --git a/src/Diffusers.jl b/src/Diffusers.jl index 0b0bba7..3e3ea63 100644 --- a/src/Diffusers.jl +++ b/src/Diffusers.jl @@ -27,7 +27,7 @@ export include("Schedulers/Schedulers.jl") import .Schedulers: - # Scheduler + # Schedulers DDPM, DDIM, @@ -51,7 +51,7 @@ import .Schedulers: VELOCITY export - # Scheduler + # Schedulers DDPM, DDIM,