mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-10-17 23:26:19 +00:00
🐛 (examples/swissroll) better performances + switch to linear beta schedule
This commit is contained in:
parent
b931b5ceb5
commit
2cf9c49202
|
@ -1,9 +1,12 @@
|
||||||
[deps]
|
[deps]
|
||||||
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
|
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
|
||||||
|
DenoisingDiffusion = "32e9e46b-ad0f-4c80-b32d-4f6f824844ef"
|
||||||
|
Diffusers = "90edb7a8-79d7-49b2-b6b1-9322c3fdead8"
|
||||||
|
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
|
||||||
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
|
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
|
||||||
ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
|
ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
|
||||||
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
|
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
|
||||||
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
|
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
|
||||||
|
PlotlyJS = "f0f68f2c-4968-5e81-91da-67840de0976a"
|
||||||
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
|
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
|
||||||
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
|
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
|
||||||
DenoisingDiffusion = "32e9e46b-ad0f-4c80-b32d-4f6f824844ef"
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import Diffusers
|
import Diffusers
|
||||||
import Diffusers.Schedulers
|
import Diffusers.Schedulers
|
||||||
import Diffusers.Schedulers: DDPM
|
import Diffusers.Schedulers: DDPM
|
||||||
import Diffusers.BetaSchedules: cosine_beta_schedule, rescale_zero_terminal_snr
|
import Diffusers.BetaSchedules: linear_beta_schedule
|
||||||
using Flux
|
using Flux
|
||||||
using Random
|
using Random
|
||||||
using Plots
|
using Plots
|
||||||
|
@ -10,7 +10,7 @@ using DenoisingDiffusion
|
||||||
using LaTeXStrings
|
using LaTeXStrings
|
||||||
|
|
||||||
function make_spiral(n_samples::Integer=1000, t_min::Real=1.5π, t_max::Real=4.5π)
|
function make_spiral(n_samples::Integer=1000, t_min::Real=1.5π, t_max::Real=4.5π)
|
||||||
t = rand(n_samples) * (t_max - t_min) .+ t_min
|
t = rand(typeof(t_min), n_samples) * (t_max - t_min) .+ t_min
|
||||||
|
|
||||||
x = t .* cos.(t)
|
x = t .* cos.(t)
|
||||||
y = t .* sin.(t)
|
y = t .* sin.(t)
|
||||||
|
@ -29,7 +29,7 @@ function normalize_neg_one_to_one(x)
|
||||||
end
|
end
|
||||||
|
|
||||||
n_points = 1000
|
n_points = 1000
|
||||||
dataset = make_spiral(n_points, 1.5π, 4.5π)
|
dataset = make_spiral(n_points, 1.5f0 * π, 4.5f0 * π)
|
||||||
dataset = normalize_neg_one_to_one(dataset)
|
dataset = normalize_neg_one_to_one(dataset)
|
||||||
scatter(dataset[1, :], dataset[2, :],
|
scatter(dataset[1, :], dataset[2, :],
|
||||||
alpha=0.5,
|
alpha=0.5,
|
||||||
|
@ -39,12 +39,12 @@ scatter(dataset[1, :], dataset[2, :],
|
||||||
num_timesteps = 100
|
num_timesteps = 100
|
||||||
scheduler = DDPM(
|
scheduler = DDPM(
|
||||||
Vector{Float32},
|
Vector{Float32},
|
||||||
cosine_beta_schedule(num_timesteps)
|
linear_beta_schedule(num_timesteps)
|
||||||
);
|
);
|
||||||
|
|
||||||
data = dataset
|
data = dataset
|
||||||
noise = randn(Float32, size(data))
|
noise = randn(Float32, size(data))
|
||||||
anim = @animate for t in cat(fill(0, 2), 1:num_timesteps, fill(num_timesteps, 2), dims=1)
|
anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20), dims=1)
|
||||||
if t == 0
|
if t == 0
|
||||||
scatter(noise[1, :], noise[2, :],
|
scatter(noise[1, :], noise[2, :],
|
||||||
alpha=0.3,
|
alpha=0.3,
|
||||||
|
@ -86,7 +86,7 @@ anim = @animate for t in cat(fill(0, 2), 1:num_timesteps, fill(num_timesteps, 2)
|
||||||
ylims!(-3, 3)
|
ylims!(-3, 3)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
gif(anim, anim.dir * ".gif", fps=2)
|
gif(anim, anim.dir * ".gif", fps=20)
|
||||||
|
|
||||||
d_hid = 32
|
d_hid = 32
|
||||||
model = ConditionalChain(
|
model = ConditionalChain(
|
||||||
|
@ -95,7 +95,8 @@ model = ConditionalChain(
|
||||||
Dense(2, d_hid),
|
Dense(2, d_hid),
|
||||||
Chain(
|
Chain(
|
||||||
SinusoidalPositionEmbedding(num_timesteps, d_hid),
|
SinusoidalPositionEmbedding(num_timesteps, d_hid),
|
||||||
Dense(d_hid, d_hid))
|
Dense(d_hid, d_hid)
|
||||||
|
)
|
||||||
),
|
),
|
||||||
relu,
|
relu,
|
||||||
Parallel(
|
Parallel(
|
||||||
|
@ -103,7 +104,8 @@ model = ConditionalChain(
|
||||||
Dense(d_hid, d_hid),
|
Dense(d_hid, d_hid),
|
||||||
Chain(
|
Chain(
|
||||||
SinusoidalPositionEmbedding(num_timesteps, d_hid),
|
SinusoidalPositionEmbedding(num_timesteps, d_hid),
|
||||||
Dense(d_hid, d_hid))
|
Dense(d_hid, d_hid)
|
||||||
|
)
|
||||||
),
|
),
|
||||||
relu,
|
relu,
|
||||||
Parallel(
|
Parallel(
|
||||||
|
@ -111,7 +113,8 @@ model = ConditionalChain(
|
||||||
Dense(d_hid, d_hid),
|
Dense(d_hid, d_hid),
|
||||||
Chain(
|
Chain(
|
||||||
SinusoidalPositionEmbedding(num_timesteps, d_hid),
|
SinusoidalPositionEmbedding(num_timesteps, d_hid),
|
||||||
Dense(d_hid, d_hid))
|
Dense(d_hid, d_hid)
|
||||||
|
)
|
||||||
),
|
),
|
||||||
relu,
|
relu,
|
||||||
Dense(d_hid, 2),
|
Dense(d_hid, 2),
|
||||||
|
@ -119,9 +122,9 @@ model = ConditionalChain(
|
||||||
|
|
||||||
model(data, [5])
|
model(data, [5])
|
||||||
|
|
||||||
num_epochs = 10000;
|
num_epochs = 5000;
|
||||||
loss = Flux.Losses.mse;
|
loss = Flux.Losses.mse;
|
||||||
opt = Flux.setup(Adam(0.001), model);
|
opt = Flux.setup(AdamW(), model);
|
||||||
dataloader = Flux.DataLoader(dataset; batchsize=32, shuffle=true);
|
dataloader = Flux.DataLoader(dataset; batchsize=32, shuffle=true);
|
||||||
progress = Progress(num_epochs; desc="training", showspeed=true);
|
progress = Progress(num_epochs; desc="training", showspeed=true);
|
||||||
for epoch = 1:num_epochs
|
for epoch = 1:num_epochs
|
||||||
|
@ -154,7 +157,7 @@ anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2
|
||||||
p1 = scatter(dataset[1, :], dataset[2, :],
|
p1 = scatter(dataset[1, :], dataset[2, :],
|
||||||
alpha=0.01,
|
alpha=0.01,
|
||||||
aspectratio=:equal,
|
aspectratio=:equal,
|
||||||
title=L"x_t",
|
title=L"\hat{x}_t",
|
||||||
legend=false,
|
legend=false,
|
||||||
)
|
)
|
||||||
scatter!(sample_old[1, :], sample_old[2, :])
|
scatter!(sample_old[1, :], sample_old[2, :])
|
||||||
|
@ -162,7 +165,7 @@ anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2
|
||||||
p2 = scatter(dataset[1, :], dataset[2, :],
|
p2 = scatter(dataset[1, :], dataset[2, :],
|
||||||
alpha=0.01,
|
alpha=0.01,
|
||||||
aspectratio=:equal,
|
aspectratio=:equal,
|
||||||
title=L"x_0",
|
title=L"\hat{x}_0",
|
||||||
legend=false,
|
legend=false,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -171,6 +174,7 @@ anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2
|
||||||
plot(p1, p2,
|
plot(p1, p2,
|
||||||
layout=l,
|
layout=l,
|
||||||
plot_title=latexstring("t = $(t_str)"),
|
plot_title=latexstring("t = $(t_str)"),
|
||||||
|
size=(700, 400),
|
||||||
)
|
)
|
||||||
xlims!(-2, 2)
|
xlims!(-2, 2)
|
||||||
ylims!(-2, 2)
|
ylims!(-2, 2)
|
||||||
|
@ -180,7 +184,7 @@ anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2
|
||||||
alpha=0.01,
|
alpha=0.01,
|
||||||
aspectratio=:equal,
|
aspectratio=:equal,
|
||||||
legend=false,
|
legend=false,
|
||||||
title=L"x_t",
|
title=L"\hat{x}_t",
|
||||||
)
|
)
|
||||||
scatter!(sample[1, :], sample[2, :])
|
scatter!(sample[1, :], sample[2, :])
|
||||||
|
|
||||||
|
@ -188,7 +192,7 @@ anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2
|
||||||
alpha=0.01,
|
alpha=0.01,
|
||||||
aspectratio=:equal,
|
aspectratio=:equal,
|
||||||
legend=false,
|
legend=false,
|
||||||
title=L"x_0",
|
title=L"\hat{x}_0",
|
||||||
)
|
)
|
||||||
scatter!(x_0[1, :], x_0[2, :])
|
scatter!(x_0[1, :], x_0[2, :])
|
||||||
|
|
||||||
|
@ -197,6 +201,7 @@ anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2
|
||||||
plot(p1, p2,
|
plot(p1, p2,
|
||||||
layout=l,
|
layout=l,
|
||||||
plot_title=latexstring("t = $(t_str)"),
|
plot_title=latexstring("t = $(t_str)"),
|
||||||
|
size=(700, 400),
|
||||||
)
|
)
|
||||||
xlims!(-2, 2)
|
xlims!(-2, 2)
|
||||||
ylims!(-2, 2)
|
ylims!(-2, 2)
|
||||||
|
|
|
@ -121,7 +121,6 @@ function step(
|
||||||
# arxiv:2006.11239 Eq. 6
|
# arxiv:2006.11239 Eq. 6
|
||||||
# arxiv:2208.11970 Eq. 70
|
# arxiv:2208.11970 Eq. 70
|
||||||
σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ # TODO: this could be stored in the scheduler
|
σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ # TODO: this could be stored in the scheduler
|
||||||
σₜ = exp.(log.(σₜ) ./ 2) # https://github.com/huggingface/diffusers/blob/160474ac61934cc22793d6cebea118c171175dbc/src/diffusers/schedulers/scheduling_ddpm.py#L306
|
|
||||||
xₜ₋₁ = μ̃ₜ + σₜ .* randn(size(ϵᵧ))
|
xₜ₋₁ = μ̃ₜ + σₜ .* randn(size(ϵᵧ))
|
||||||
|
|
||||||
return xₜ₋₁, x̂₀
|
return xₜ₋₁, x̂₀
|
||||||
|
|
Loading…
Reference in a new issue