🐛 (examples/swissroll) better performances + switch to linear beta schedule

This commit is contained in:
Laureηt 2023-08-04 21:48:27 +02:00
parent b931b5ceb5
commit 2cf9c49202
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
3 changed files with 24 additions and 17 deletions

View file

@ -1,9 +1,12 @@
[deps] [deps]
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
DenoisingDiffusion = "32e9e46b-ad0f-4c80-b32d-4f6f824844ef"
Diffusers = "90edb7a8-79d7-49b2-b6b1-9322c3fdead8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31" ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
PlotlyJS = "f0f68f2c-4968-5e81-91da-67840de0976a"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
DenoisingDiffusion = "32e9e46b-ad0f-4c80-b32d-4f6f824844ef"

View file

@ -1,7 +1,7 @@
import Diffusers import Diffusers
import Diffusers.Schedulers import Diffusers.Schedulers
import Diffusers.Schedulers: DDPM import Diffusers.Schedulers: DDPM
import Diffusers.BetaSchedules: cosine_beta_schedule, rescale_zero_terminal_snr import Diffusers.BetaSchedules: linear_beta_schedule
using Flux using Flux
using Random using Random
using Plots using Plots
@ -10,7 +10,7 @@ using DenoisingDiffusion
using LaTeXStrings using LaTeXStrings
function make_spiral(n_samples::Integer=1000, t_min::Real=1.5π, t_max::Real=4.5π) function make_spiral(n_samples::Integer=1000, t_min::Real=1.5π, t_max::Real=4.5π)
t = rand(n_samples) * (t_max - t_min) .+ t_min t = rand(typeof(t_min), n_samples) * (t_max - t_min) .+ t_min
x = t .* cos.(t) x = t .* cos.(t)
y = t .* sin.(t) y = t .* sin.(t)
@ -29,7 +29,7 @@ function normalize_neg_one_to_one(x)
end end
n_points = 1000 n_points = 1000
dataset = make_spiral(n_points, 1.5π, 4.5π) dataset = make_spiral(n_points, 1.5f0 * π, 4.5f0 * π)
dataset = normalize_neg_one_to_one(dataset) dataset = normalize_neg_one_to_one(dataset)
scatter(dataset[1, :], dataset[2, :], scatter(dataset[1, :], dataset[2, :],
alpha=0.5, alpha=0.5,
@ -39,12 +39,12 @@ scatter(dataset[1, :], dataset[2, :],
num_timesteps = 100 num_timesteps = 100
scheduler = DDPM( scheduler = DDPM(
Vector{Float32}, Vector{Float32},
cosine_beta_schedule(num_timesteps) linear_beta_schedule(num_timesteps)
); );
data = dataset data = dataset
noise = randn(Float32, size(data)) noise = randn(Float32, size(data))
anim = @animate for t in cat(fill(0, 2), 1:num_timesteps, fill(num_timesteps, 2), dims=1) anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20), dims=1)
if t == 0 if t == 0
scatter(noise[1, :], noise[2, :], scatter(noise[1, :], noise[2, :],
alpha=0.3, alpha=0.3,
@ -86,7 +86,7 @@ anim = @animate for t in cat(fill(0, 2), 1:num_timesteps, fill(num_timesteps, 2)
ylims!(-3, 3) ylims!(-3, 3)
end end
end end
gif(anim, anim.dir * ".gif", fps=2) gif(anim, anim.dir * ".gif", fps=20)
d_hid = 32 d_hid = 32
model = ConditionalChain( model = ConditionalChain(
@ -95,7 +95,8 @@ model = ConditionalChain(
Dense(2, d_hid), Dense(2, d_hid),
Chain( Chain(
SinusoidalPositionEmbedding(num_timesteps, d_hid), SinusoidalPositionEmbedding(num_timesteps, d_hid),
Dense(d_hid, d_hid)) Dense(d_hid, d_hid)
)
), ),
relu, relu,
Parallel( Parallel(
@ -103,7 +104,8 @@ model = ConditionalChain(
Dense(d_hid, d_hid), Dense(d_hid, d_hid),
Chain( Chain(
SinusoidalPositionEmbedding(num_timesteps, d_hid), SinusoidalPositionEmbedding(num_timesteps, d_hid),
Dense(d_hid, d_hid)) Dense(d_hid, d_hid)
)
), ),
relu, relu,
Parallel( Parallel(
@ -111,7 +113,8 @@ model = ConditionalChain(
Dense(d_hid, d_hid), Dense(d_hid, d_hid),
Chain( Chain(
SinusoidalPositionEmbedding(num_timesteps, d_hid), SinusoidalPositionEmbedding(num_timesteps, d_hid),
Dense(d_hid, d_hid)) Dense(d_hid, d_hid)
)
), ),
relu, relu,
Dense(d_hid, 2), Dense(d_hid, 2),
@ -119,9 +122,9 @@ model = ConditionalChain(
model(data, [5]) model(data, [5])
num_epochs = 10000; num_epochs = 5000;
loss = Flux.Losses.mse; loss = Flux.Losses.mse;
opt = Flux.setup(Adam(0.001), model); opt = Flux.setup(AdamW(), model);
dataloader = Flux.DataLoader(dataset; batchsize=32, shuffle=true); dataloader = Flux.DataLoader(dataset; batchsize=32, shuffle=true);
progress = Progress(num_epochs; desc="training", showspeed=true); progress = Progress(num_epochs; desc="training", showspeed=true);
for epoch = 1:num_epochs for epoch = 1:num_epochs
@ -154,7 +157,7 @@ anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2
p1 = scatter(dataset[1, :], dataset[2, :], p1 = scatter(dataset[1, :], dataset[2, :],
alpha=0.01, alpha=0.01,
aspectratio=:equal, aspectratio=:equal,
title=L"x_t", title=L"\hat{x}_t",
legend=false, legend=false,
) )
scatter!(sample_old[1, :], sample_old[2, :]) scatter!(sample_old[1, :], sample_old[2, :])
@ -162,7 +165,7 @@ anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2
p2 = scatter(dataset[1, :], dataset[2, :], p2 = scatter(dataset[1, :], dataset[2, :],
alpha=0.01, alpha=0.01,
aspectratio=:equal, aspectratio=:equal,
title=L"x_0", title=L"\hat{x}_0",
legend=false, legend=false,
) )
@ -171,6 +174,7 @@ anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2
plot(p1, p2, plot(p1, p2,
layout=l, layout=l,
plot_title=latexstring("t = $(t_str)"), plot_title=latexstring("t = $(t_str)"),
size=(700, 400),
) )
xlims!(-2, 2) xlims!(-2, 2)
ylims!(-2, 2) ylims!(-2, 2)
@ -180,7 +184,7 @@ anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2
alpha=0.01, alpha=0.01,
aspectratio=:equal, aspectratio=:equal,
legend=false, legend=false,
title=L"x_t", title=L"\hat{x}_t",
) )
scatter!(sample[1, :], sample[2, :]) scatter!(sample[1, :], sample[2, :])
@ -188,7 +192,7 @@ anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2
alpha=0.01, alpha=0.01,
aspectratio=:equal, aspectratio=:equal,
legend=false, legend=false,
title=L"x_0", title=L"\hat{x}_0",
) )
scatter!(x_0[1, :], x_0[2, :]) scatter!(x_0[1, :], x_0[2, :])
@ -197,6 +201,7 @@ anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2
plot(p1, p2, plot(p1, p2,
layout=l, layout=l,
plot_title=latexstring("t = $(t_str)"), plot_title=latexstring("t = $(t_str)"),
size=(700, 400),
) )
xlims!(-2, 2) xlims!(-2, 2)
ylims!(-2, 2) ylims!(-2, 2)

View file

@ -121,7 +121,6 @@ function step(
# arxiv:2006.11239 Eq. 6 # arxiv:2006.11239 Eq. 6
# arxiv:2208.11970 Eq. 70 # arxiv:2208.11970 Eq. 70
σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ # TODO: this could be stored in the scheduler σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ # TODO: this could be stored in the scheduler
σₜ = exp.(log.(σₜ) ./ 2) # https://github.com/huggingface/diffusers/blob/160474ac61934cc22793d6cebea118c171175dbc/src/diffusers/schedulers/scheduling_ddpm.py#L306
xₜ₋₁ = μ̃ₜ + σₜ .* randn(size(ϵᵧ)) xₜ₋₁ = μ̃ₜ + σₜ .* randn(size(ϵᵧ))
return xₜ₋₁, x̂₀ return xₜ₋₁, x̂₀