(examples/swissroll) subplot x_t-1 and x_0

This commit is contained in:
Laureηt 2023-07-25 21:01:56 +02:00
parent 4122a47b6d
commit ce9a4ec323
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI

View file

@ -104,7 +104,7 @@ model = ConditionalChain(
model(data, [100])
num_epochs = 100
num_epochs = 1000
loss = Flux.Losses.mse
opt = Flux.setup(Adam(0.0001), model)
dataloader = Flux.DataLoader(data |> cpu; batchsize=32, shuffle=true);
@ -117,7 +117,7 @@ for epoch = 1:num_epochs
noisy_data = Diffusers.add_noise(scheduler, data, noise, timesteps)
grads = Flux.gradient(model) do m
model_output = m(noisy_data, timesteps)
noise_prediction = Diffusers.step(scheduler, noisy_data, model_output, timesteps)
noise_prediction, _ = Diffusers.step(scheduler, noisy_data, model_output, timesteps)
loss(noise, noise_prediction)
end
Flux.update!(opt, params, grads)
@ -128,21 +128,40 @@ end
# sampling animation
anim = @animate for timestep in num_timesteps:-1:2
model_output = model(data, [timestep])
sampled_data = Diffusers.step(scheduler, data, model_output, [timestep])
scatter(sampled_data[1, :], sampled_data[2, :],
sampled_data, x0_pred = Diffusers.step(scheduler, data, model_output, [timestep])
p1 = scatter(sampled_data[1, :], sampled_data[2, :],
alpha=0.5,
aspectratio=:equal,
label="sampled data",
legend=false,
)
scatter!(data[1, :], data[2, :],
alpha=0.5,
aspectratio=:equal,
label="data",
)
p2 = scatter(x0_pred[1, :], x0_pred[2, :],
alpha=0.5,
aspectratio=:equal,
label="sampled data",
legend=false,
)
scatter!(data[1, :], data[2, :],
alpha=0.5,
aspectratio=:equal,
label="data",
)
l = @layout [a b]
i_str = lpad(timestep, 3, "0")
title!("t = $(i_str)")
xlims!(-3, 3)
ylims!(-3, 3)
plot(p1, p2,
layout=l,
plot_title="t = $(i_str)",
)
xlims!(-2, 2)
ylims!(-2, 2)
end
gif(anim, "sampling.gif", fps=30)