2023-07-03 19:12:22 +00:00
|
|
|
import Diffusers
|
2023-07-29 15:52:36 +00:00
|
|
|
import Diffusers.Schedulers
|
|
|
|
import Diffusers.Schedulers: DDPM
|
2023-08-04 19:48:27 +00:00
|
|
|
import Diffusers.BetaSchedules: linear_beta_schedule
|
2023-07-05 19:13:26 +00:00
|
|
|
using Flux
|
2023-07-03 19:12:22 +00:00
|
|
|
using Random
|
|
|
|
using Plots
|
2023-07-23 12:48:15 +00:00
|
|
|
using ProgressMeter
|
2023-07-30 15:55:22 +00:00
|
|
|
using DenoisingDiffusion
|
|
|
|
using LaTeXStrings
|
2023-07-03 19:12:22 +00:00
|
|
|
|
2023-07-30 15:55:22 +00:00
|
|
|
"""
    make_spiral(n_samples=1000, t_min=1.5π, t_max=4.5π; rng=Random.default_rng())

Sample `n_samples` points on the spiral `(t * cos(t), t * sin(t))`, with the
parameter `t` drawn uniformly between `t_min` and `t_max`.

# Arguments
- `n_samples::Integer`: number of points to draw.
- `t_min::Real`, `t_max::Real`: bounds of the spiral parameter. Any real
  number is accepted (integers, irrationals such as `π`, floats); the bounds
  are converted to floating point before sampling.

# Keywords
- `rng::AbstractRNG`: random number generator used for sampling; pass a seeded
  generator for reproducible output.

# Returns
A `2 × n_samples` matrix whose first row holds the x coordinates and whose
second row holds the y coordinates. The element type is the floating-point
type obtained by promoting `float(t_min)` and `float(t_max)` (for example
`Float32` when both bounds are `Float32`).
"""
function make_spiral(
    n_samples::Integer=1000, t_min::Real=1.5π, t_max::Real=4.5π;
    rng::AbstractRNG=Random.default_rng()
)
    # Sample in a floating-point type: `rand(Int, n)` would draw full-range
    # integers and `rand(typeof(π), n)` has no method at all.
    lo, hi = float(t_min), float(t_max)
    T = promote_type(typeof(lo), typeof(hi))
    t = rand(rng, T, n_samples) .* (hi - lo) .+ lo
    x = t .* cos.(t)
    y = t .* sin.(t)
    # `[x y]` is n × 2; transpose so that each column is one sample.
    return permutedims([x y], (2, 1))
end
|
|
|
|
|
|
|
|
"""
    normalize_zero_to_one(x)

Rescale `x` linearly so that its smallest entry maps to 0 and its largest
entry maps to 1.

The minimum and maximum are taken over all entries of `x` at once, so every
entry is shifted and scaled by the same amounts. If all entries are equal the
span is zero and the result is a division by zero.
"""
function normalize_zero_to_one(x)
    lo, hi = extrema(x)
    span = hi - lo
    return (x .- lo) ./ span
end
|
|
|
|
|
|
|
|
"""
    normalize_neg_one_to_one(x)

Rescale `x` linearly so that its smallest entry maps to -1 and its largest
entry maps to 1, by stretching the output of `normalize_zero_to_one`.
"""
function normalize_neg_one_to_one(x)
    unit_scaled = normalize_zero_to_one(x)
    return 2 * unit_scaled .- 1
end
|
|
|
|
|
2023-08-01 20:59:12 +00:00
|
|
|
# Training data: points on a spiral, rescaled jointly into [-1, 1].
# Float32 bounds are passed so the spiral is sampled from Float32 limits.
n_points = 1000
dataset = normalize_neg_one_to_one(make_spiral(n_points, 1.5f0 * π, 4.5f0 * π))

# Quick visual check of the data (row 1 = x, row 2 = y).
scatter(dataset[1, :], dataset[2, :]; alpha=0.5, aspectratio=:equal)

# Noise scheduler built from a linear beta schedule with `num_timesteps` steps.
num_timesteps = 100
scheduler = DDPM(linear_beta_schedule(num_timesteps));
|
2023-07-03 19:12:22 +00:00
|
|
|
|
2023-08-01 20:59:12 +00:00
|
|
|
# Forward process animation: show the clean data, a fixed noise sample, and the
# data after `t` noising steps, for every timestep.
data = dataset
noise = randn(Float32, size(data))

# Twenty frames are held at t = 0 and twenty more at the final timestep.
anim = @animate for t in vcat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20))
    # At t = 0 nothing has been added yet, so the "noisy data" is the data itself.
    noisy_data = if t == 0
        data
    else
        Diffusers.Schedulers.forward(scheduler, data, noise, [t])
    end

    scatter(noise[1, :], noise[2, :];
        alpha=0.3, aspectratio=:equal, label="noise", legend=:outertopright)
    scatter!(data[1, :], data[2, :];
        alpha=0.3, aspectratio=:equal, label="data")
    scatter!(noisy_data[1, :], noisy_data[2, :];
        aspectratio=:equal, label="noisy data")

    title!(latexstring("t = " * lpad(t, 3, "0")))
    xlims!(-3, 3)
    ylims!(-3, 3)
end

gif(anim, anim.dir * ".gif", fps=20)
|
2023-07-05 19:13:26 +00:00
|
|
|
|
|
|
|
d_hid = 32

# One layer of the network: a dense projection of the input summed
# (broadcast `.+`) with a dense projection of a sinusoidal embedding.
# NOTE(review): presumably `ConditionalChain` feeds the timestep argument to
# the embedding branch; confirm against DenoisingDiffusion.
time_conditioned_dense(d_in, d_out) = Parallel(
    .+,
    Dense(d_in, d_out),
    Chain(
        SinusoidalPositionEmbedding(num_timesteps, d_out),
        Dense(d_out, d_out)
    )
)

# Three such layers with ReLU in between, then a projection back to 2-D.
model = ConditionalChain(
    time_conditioned_dense(2, d_hid),
    relu,
    time_conditioned_dense(d_hid, d_hid),
    relu,
    time_conditioned_dense(d_hid, d_hid),
    relu,
    Dense(d_hid, 2),
)

# Smoke test: run the whole dataset through the model at timestep 5.
model(data, [5])
|
2023-07-05 19:13:26 +00:00
|
|
|
|
2023-08-04 19:48:27 +00:00
|
|
|
# Training configuration.
num_epochs = 5000;
loss = Flux.Losses.mse;
opt = Flux.setup(AdamW(), model);
dataloader = Flux.DataLoader(dataset; batchsize=32, shuffle=true);
progress = Progress(num_epochs; desc="training", showspeed=true);

# Each step: noise a batch at random timesteps and train the model to predict
# the noise that was added.
for epoch in 1:num_epochs
    for batch in dataloader
        noise = randn(Float32, size(batch))
        # One timestep per sample; samples lie along the last dimension.
        timesteps = rand(1:num_timesteps, size(batch, ndims(batch)))
        noisy_data = Diffusers.Schedulers.forward(scheduler, batch, noise, timesteps)
        grads = Flux.gradient(m -> loss(noise, m(noisy_data, timesteps)), model)
        Flux.update!(opt, model, first(grads))
    end
    # The progress bar advances once per epoch, not once per batch.
    ProgressMeter.next!(progress)
end
|
|
|
|
|
2023-07-30 15:55:22 +00:00
|
|
|
## sampling animation

# Start from a fixed-seed Gaussian sample of 100 two-dimensional points and
# walk the reverse process from the last timestep down to 1.
sample = randn(MersenneTwister(1), Float32, 2, 100)
sample_old = sample  # keep the initial noise for the first animation frames

# One entry per reverse step: (sample after the step, predicted x0, timestep).
predictions = []
for timestep in num_timesteps:-1:1
    # `sample` is rebound on every iteration. Without this declaration a
    # top-level loop in a non-interactive run (`julia file.jl`) treats `sample`
    # as a new local and fails with UndefVarError on the first read; with it
    # the loop behaves the same in scripts, the REPL and notebooks.
    global sample
    model_output = model(sample, [timestep])
    sample, x0_pred = Diffusers.Schedulers.reverse(scheduler, sample, model_output, [timestep])
    push!(predictions, (sample, x0_pred, timestep))
end
|
|
|
|
|
2023-08-01 20:59:12 +00:00
|
|
|
# Draw the (nearly transparent) training data as a backdrop and, when given,
# overlay a set of points on top of it. Returns the panel.
function reference_panel(panel_title, overlay=nothing)
    panel = scatter(dataset[1, :], dataset[2, :];
        alpha=0.01, aspectratio=:equal, title=panel_title, legend=false)
    if overlay !== nothing
        scatter!(panel, overlay[1, :], overlay[2, :])
    end
    return panel
end

# Reverse process animation: left panel shows the current sample, right panel
# the model's estimate of the clean data. Twenty frames are held on the
# initial noise and twenty more on the final step.
anim = @animate for i in vcat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20))
    if i == 0
        # Before any reverse step: only the initial noise, no x0 estimate yet.
        x_t, x_0, shown_step = sample_old, nothing, num_timesteps
    else
        x_t, x_0, timestep = predictions[i]
        # The stored sample is the result of the step taken at `timestep`.
        shown_step = timestep - 1
    end

    p1 = reference_panel(L"\hat{x}_t", x_t)
    p2 = reference_panel(L"\hat{x}_0", x_0)

    grid = @layout [a b]
    t_str = lpad(shown_step, 3, "0")
    plot(p1, p2;
        layout=grid,
        plot_title=latexstring("t = $(t_str)"),
        size=(700, 400))
    xlims!(-2, 2)
    ylims!(-2, 2)
end

gif(anim, anim.dir * ".gif", fps=20)
|