mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-09-16 17:45:29 +00:00
✨ (examples/swissroll) subplot x_t-1 and x_0
This commit is contained in:
parent
4122a47b6d
commit
ce9a4ec323
|
@ -104,7 +104,7 @@ model = ConditionalChain(
|
|||
|
||||
model(data, [100])
|
||||
|
||||
num_epochs = 100
|
||||
num_epochs = 1000
|
||||
loss = Flux.Losses.mse
|
||||
opt = Flux.setup(Adam(0.0001), model)
|
||||
dataloader = Flux.DataLoader(data |> cpu; batchsize=32, shuffle=true);
|
||||
|
@ -117,7 +117,7 @@ for epoch = 1:num_epochs
|
|||
noisy_data = Diffusers.add_noise(scheduler, data, noise, timesteps)
|
||||
grads = Flux.gradient(model) do m
|
||||
model_output = m(noisy_data, timesteps)
|
||||
noise_prediction = Diffusers.step(scheduler, noisy_data, model_output, timesteps)
|
||||
noise_prediction, _ = Diffusers.step(scheduler, noisy_data, model_output, timesteps)
|
||||
loss(noise, noise_prediction)
|
||||
end
|
||||
Flux.update!(opt, params, grads)
|
||||
|
@ -128,21 +128,40 @@ end
|
|||
# sampling animation
|
||||
anim = @animate for timestep in num_timesteps:-1:2
|
||||
model_output = model(data, [timestep])
|
||||
sampled_data = Diffusers.step(scheduler, data, model_output, [timestep])
|
||||
scatter(sampled_data[1, :], sampled_data[2, :],
|
||||
sampled_data, x0_pred = Diffusers.step(scheduler, data, model_output, [timestep])
|
||||
|
||||
p1 = scatter(sampled_data[1, :], sampled_data[2, :],
|
||||
alpha=0.5,
|
||||
aspectratio=:equal,
|
||||
label="sampled data",
|
||||
legend=false,
|
||||
)
|
||||
scatter!(data[1, :], data[2, :],
|
||||
alpha=0.5,
|
||||
aspectratio=:equal,
|
||||
label="data",
|
||||
)
|
||||
|
||||
p2 = scatter(x0_pred[1, :], x0_pred[2, :],
|
||||
alpha=0.5,
|
||||
aspectratio=:equal,
|
||||
label="sampled data",
|
||||
legend=false,
|
||||
)
|
||||
scatter!(data[1, :], data[2, :],
|
||||
alpha=0.5,
|
||||
aspectratio=:equal,
|
||||
label="data",
|
||||
)
|
||||
|
||||
l = @layout [a b]
|
||||
i_str = lpad(timestep, 3, "0")
|
||||
title!("t = $(i_str)")
|
||||
xlims!(-3, 3)
|
||||
ylims!(-3, 3)
|
||||
plot(p1, p2,
|
||||
layout=l,
|
||||
plot_title="t = $(i_str)",
|
||||
)
|
||||
xlims!(-2, 2)
|
||||
ylims!(-2, 2)
|
||||
end
|
||||
|
||||
gif(anim, "sampling.gif", fps=30)
|
||||
|
|
Loading…
Reference in a new issue