From ce9a4ec323afd4b49d2da4ba8dbe1471246fdc1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laure=CE=B7t?= Date: Tue, 25 Jul 2023 21:01:56 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20(examples/swissroll)=20subplot=20x?= =?UTF-8?q?=5Ft-1=20and=20x=5F0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/swissroll.jl | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/examples/swissroll.jl b/examples/swissroll.jl index 2dfb83a..fc6d11a 100644 --- a/examples/swissroll.jl +++ b/examples/swissroll.jl @@ -104,7 +104,7 @@ model = ConditionalChain( model(data, [100]) -num_epochs = 100 +num_epochs = 1000 loss = Flux.Losses.mse opt = Flux.setup(Adam(0.0001), model) dataloader = Flux.DataLoader(data |> cpu; batchsize=32, shuffle=true); @@ -117,7 +117,7 @@ for epoch = 1:num_epochs noisy_data = Diffusers.add_noise(scheduler, data, noise, timesteps) grads = Flux.gradient(model) do m model_output = m(noisy_data, timesteps) - noise_prediction = Diffusers.step(scheduler, noisy_data, model_output, timesteps) + noise_prediction, _ = Diffusers.step(scheduler, noisy_data, model_output, timesteps) loss(noise, noise_prediction) end Flux.update!(opt, params, grads) @@ -128,21 +128,40 @@ end # sampling animation anim = @animate for timestep in num_timesteps:-1:2 model_output = model(data, [timestep]) - sampled_data = Diffusers.step(scheduler, data, model_output, [timestep]) - scatter(sampled_data[1, :], sampled_data[2, :], + sampled_data, x0_pred = Diffusers.step(scheduler, data, model_output, [timestep]) + + p1 = scatter(sampled_data[1, :], sampled_data[2, :], alpha=0.5, aspectratio=:equal, label="sampled data", + legend=false, ) scatter!(data[1, :], data[2, :], alpha=0.5, aspectratio=:equal, label="data", ) + + p2 = scatter(x0_pred[1, :], x0_pred[2, :], + alpha=0.5, + aspectratio=:equal, + label="sampled data", + legend=false, + ) + scatter!(data[1, :], data[2, :], + alpha=0.5, + aspectratio=:equal, + label="data", + ) + + l = @layout [a b] i_str = lpad(timestep, 3, "0") - title!("t = $(i_str)") - xlims!(-3, 3) - ylims!(-3, 3) + plot(p1, p2, + layout=l, + plot_title="t = $(i_str)", + ) + xlims!(-2, 2) + ylims!(-2, 2) end gif(anim, "sampling.gif", fps=30)