🐛 (examples/swissroll) better performances + switch to linear beta schedule

This commit is contained in:
Laureηt 2023-08-04 21:48:27 +02:00
parent b931b5ceb5
commit 2cf9c49202
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
3 changed files with 24 additions and 17 deletions

View file

@ -1,9 +1,12 @@
[deps]
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
DenoisingDiffusion = "32e9e46b-ad0f-4c80-b32d-4f6f824844ef"
Diffusers = "90edb7a8-79d7-49b2-b6b1-9322c3fdead8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
PlotlyJS = "f0f68f2c-4968-5e81-91da-67840de0976a"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
DenoisingDiffusion = "32e9e46b-ad0f-4c80-b32d-4f6f824844ef"

View file

@ -1,7 +1,7 @@
import Diffusers
import Diffusers.Schedulers
import Diffusers.Schedulers: DDPM
import Diffusers.BetaSchedules: cosine_beta_schedule, rescale_zero_terminal_snr
import Diffusers.BetaSchedules: linear_beta_schedule
using Flux
using Random
using Plots
@ -10,7 +10,7 @@ using DenoisingDiffusion
using LaTeXStrings
function make_spiral(n_samples::Integer=1000, t_min::Real=1.5π, t_max::Real=4.5π)
t = rand(n_samples) * (t_max - t_min) .+ t_min
t = rand(typeof(t_min), n_samples) * (t_max - t_min) .+ t_min
x = t .* cos.(t)
y = t .* sin.(t)
@ -29,7 +29,7 @@ function normalize_neg_one_to_one(x)
end
n_points = 1000
dataset = make_spiral(n_points, 1.5π, 4.5π)
dataset = make_spiral(n_points, 1.5f0 * π, 4.5f0 * π)
dataset = normalize_neg_one_to_one(dataset)
scatter(dataset[1, :], dataset[2, :],
alpha=0.5,
@ -39,12 +39,12 @@ scatter(dataset[1, :], dataset[2, :],
num_timesteps = 100
scheduler = DDPM(
Vector{Float32},
cosine_beta_schedule(num_timesteps)
linear_beta_schedule(num_timesteps)
);
data = dataset
noise = randn(Float32, size(data))
anim = @animate for t in cat(fill(0, 2), 1:num_timesteps, fill(num_timesteps, 2), dims=1)
anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20), dims=1)
if t == 0
scatter(noise[1, :], noise[2, :],
alpha=0.3,
@ -86,7 +86,7 @@ anim = @animate for t in cat(fill(0, 2), 1:num_timesteps, fill(num_timesteps, 2)
ylims!(-3, 3)
end
end
gif(anim, anim.dir * ".gif", fps=2)
gif(anim, anim.dir * ".gif", fps=20)
d_hid = 32
model = ConditionalChain(
@ -95,7 +95,8 @@ model = ConditionalChain(
Dense(2, d_hid),
Chain(
SinusoidalPositionEmbedding(num_timesteps, d_hid),
Dense(d_hid, d_hid))
Dense(d_hid, d_hid)
)
),
relu,
Parallel(
@ -103,7 +104,8 @@ model = ConditionalChain(
Dense(d_hid, d_hid),
Chain(
SinusoidalPositionEmbedding(num_timesteps, d_hid),
Dense(d_hid, d_hid))
Dense(d_hid, d_hid)
)
),
relu,
Parallel(
@ -111,7 +113,8 @@ model = ConditionalChain(
Dense(d_hid, d_hid),
Chain(
SinusoidalPositionEmbedding(num_timesteps, d_hid),
Dense(d_hid, d_hid))
Dense(d_hid, d_hid)
)
),
relu,
Dense(d_hid, 2),
@ -119,9 +122,9 @@ model = ConditionalChain(
model(data, [5])
num_epochs = 10000;
num_epochs = 5000;
loss = Flux.Losses.mse;
opt = Flux.setup(Adam(0.001), model);
opt = Flux.setup(AdamW(), model);
dataloader = Flux.DataLoader(dataset; batchsize=32, shuffle=true);
progress = Progress(num_epochs; desc="training", showspeed=true);
for epoch = 1:num_epochs
@ -154,7 +157,7 @@ anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2
p1 = scatter(dataset[1, :], dataset[2, :],
alpha=0.01,
aspectratio=:equal,
title=L"x_t",
title=L"\hat{x}_t",
legend=false,
)
scatter!(sample_old[1, :], sample_old[2, :])
@ -162,7 +165,7 @@ anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2
p2 = scatter(dataset[1, :], dataset[2, :],
alpha=0.01,
aspectratio=:equal,
title=L"x_0",
title=L"\hat{x}_0",
legend=false,
)
@ -171,6 +174,7 @@ anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2
plot(p1, p2,
layout=l,
plot_title=latexstring("t = $(t_str)"),
size=(700, 400),
)
xlims!(-2, 2)
ylims!(-2, 2)
@ -180,7 +184,7 @@ anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2
alpha=0.01,
aspectratio=:equal,
legend=false,
title=L"x_t",
title=L"\hat{x}_t",
)
scatter!(sample[1, :], sample[2, :])
@ -188,7 +192,7 @@ anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2
alpha=0.01,
aspectratio=:equal,
legend=false,
title=L"x_0",
title=L"\hat{x}_0",
)
scatter!(x_0[1, :], x_0[2, :])
@ -197,6 +201,7 @@ anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2
plot(p1, p2,
layout=l,
plot_title=latexstring("t = $(t_str)"),
size=(700, 400),
)
xlims!(-2, 2)
ylims!(-2, 2)

View file

@ -121,7 +121,6 @@ function step(
# arxiv:2006.11239 Eq. 6
# arxiv:2208.11970 Eq. 70
σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ # TODO: this could be stored in the scheduler
σₜ = exp.(log.(σₜ) ./ 2) # https://github.com/huggingface/diffusers/blob/160474ac61934cc22793d6cebea118c171175dbc/src/diffusers/schedulers/scheduling_ddpm.py#L306
xₜ₋₁ = μ̃ₜ + σₜ .* randn(size(ϵᵧ))
return xₜ₋₁, x̂₀