diff --git a/examples/Project.toml b/examples/Project.toml index 3960d7f..9504ed5 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -1,9 +1,12 @@ [deps] BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" +DenoisingDiffusion = "32e9e46b-ad0f-4c80-b32d-4f6f824844ef" +Diffusers = "90edb7a8-79d7-49b2-b6b1-9322c3fdead8" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31" LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" +PlotlyJS = "f0f68f2c-4968-5e81-91da-67840de0976a" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" -DenoisingDiffusion = "32e9e46b-ad0f-4c80-b32d-4f6f824844ef" diff --git a/examples/swissroll.jl b/examples/swissroll.jl index b8802ac..bc7eca1 100644 --- a/examples/swissroll.jl +++ b/examples/swissroll.jl @@ -1,7 +1,7 @@ import Diffusers import Diffusers.Schedulers import Diffusers.Schedulers: DDPM -import Diffusers.BetaSchedules: cosine_beta_schedule, rescale_zero_terminal_snr +import Diffusers.BetaSchedules: linear_beta_schedule using Flux using Random using Plots @@ -10,7 +10,7 @@ using DenoisingDiffusion using LaTeXStrings function make_spiral(n_samples::Integer=1000, t_min::Real=1.5π, t_max::Real=4.5π) - t = rand(n_samples) * (t_max - t_min) .+ t_min + t = rand(typeof(t_min), n_samples) * (t_max - t_min) .+ t_min x = t .* cos.(t) y = t .* sin.(t) @@ -29,7 +29,7 @@ function normalize_neg_one_to_one(x) end n_points = 1000 -dataset = make_spiral(n_points, 1.5π, 4.5π) +dataset = make_spiral(n_points, 1.5f0 * π, 4.5f0 * π) dataset = normalize_neg_one_to_one(dataset) scatter(dataset[1, :], dataset[2, :], alpha=0.5, @@ -39,12 +39,12 @@ scatter(dataset[1, :], dataset[2, :], num_timesteps = 100 scheduler = DDPM( Vector{Float32}, - cosine_beta_schedule(num_timesteps) + linear_beta_schedule(num_timesteps) ); data = dataset noise = randn(Float32, size(data)) -anim = @animate for t in cat(fill(0, 2), 1:num_timesteps, fill(num_timesteps, 2), dims=1) +anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20), dims=1) if t == 0 scatter(noise[1, :], noise[2, :], alpha=0.3, @@ -86,7 +86,7 @@ anim = @animate for t in cat(fill(0, 2), 1:num_timesteps, fill(num_timesteps, 2) ylims!(-3, 3) end end -gif(anim, anim.dir * ".gif", fps=2) +gif(anim, anim.dir * ".gif", fps=20) d_hid = 32 model = ConditionalChain( @@ -95,7 +95,8 @@ model = ConditionalChain( Dense(2, d_hid), Chain( SinusoidalPositionEmbedding(num_timesteps, d_hid), - Dense(d_hid, d_hid)) + Dense(d_hid, d_hid) + ) ), relu, Parallel( @@ -103,7 +104,8 @@ model = ConditionalChain( Dense(d_hid, d_hid), Chain( SinusoidalPositionEmbedding(num_timesteps, d_hid), - Dense(d_hid, d_hid)) + Dense(d_hid, d_hid) + ) ), relu, Parallel( @@ -111,7 +113,8 @@ model = ConditionalChain( Dense(d_hid, d_hid), Chain( SinusoidalPositionEmbedding(num_timesteps, d_hid), - Dense(d_hid, d_hid)) + Dense(d_hid, d_hid) + ) ), relu, Dense(d_hid, 2), @@ -119,9 +122,9 @@ model = ConditionalChain( model(data, [5]) -num_epochs = 10000; +num_epochs = 5000; loss = Flux.Losses.mse; -opt = Flux.setup(Adam(0.001), model); +opt = Flux.setup(AdamW(), model); dataloader = Flux.DataLoader(dataset; batchsize=32, shuffle=true); progress = Progress(num_epochs; desc="training", showspeed=true); for epoch = 1:num_epochs @@ -154,7 +157,7 @@ anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2 p1 = scatter(dataset[1, :], dataset[2, :], alpha=0.01, aspectratio=:equal, - title=L"x_t", + title=L"\hat{x}_t", legend=false, ) scatter!(sample_old[1, :], sample_old[2, :]) @@ -162,7 +165,7 @@ anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2 p2 = scatter(dataset[1, :], dataset[2, :], alpha=0.01, aspectratio=:equal, - title=L"x_0", + title=L"\hat{x}_0", legend=false, ) @@ -171,6 +174,7 @@ anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2 plot(p1, p2, layout=l, plot_title=latexstring("t = $(t_str)"), + size=(700, 400), ) xlims!(-2, 2) ylims!(-2, 2) @@ -180,7 +184,7 @@ anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2 alpha=0.01, aspectratio=:equal, legend=false, - title=L"x_t", + title=L"\hat{x}_t", ) scatter!(sample[1, :], sample[2, :]) @@ -188,7 +192,7 @@ anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2 alpha=0.01, aspectratio=:equal, legend=false, - title=L"x_0", + title=L"\hat{x}_0", ) scatter!(x_0[1, :], x_0[2, :]) @@ -197,6 +201,7 @@ anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2 plot(p1, p2, layout=l, plot_title=latexstring("t = $(t_str)"), + size=(700, 400), ) xlims!(-2, 2) ylims!(-2, 2) diff --git a/src/Schedulers/DDPM.jl b/src/Schedulers/DDPM.jl index 0bd7766..846e6f1 100644 --- a/src/Schedulers/DDPM.jl +++ b/src/Schedulers/DDPM.jl @@ -121,7 +121,6 @@ function step( # arxiv:2006.11239 Eq. 6 # arxiv:2208.11970 Eq. 70 σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ # TODO: this could be stored in the scheduler - σₜ = exp.(log.(σₜ) ./ 2) # https://github.com/huggingface/diffusers/blob/160474ac61934cc22793d6cebea118c171175dbc/src/diffusers/schedulers/scheduling_ddpm.py#L306 xₜ₋₁ = μ̃ₜ + σₜ .* randn(size(ϵᵧ)) return xₜ₋₁, x̂₀