mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-09-18 18:45:28 +00:00
✨ (examples/swissroll) subplot x_t-1 and x_0
This commit is contained in:
parent
4122a47b6d
commit
ce9a4ec323
|
@ -104,7 +104,7 @@ model = ConditionalChain(
|
||||||
|
|
||||||
model(data, [100])
|
model(data, [100])
|
||||||
|
|
||||||
num_epochs = 100
|
num_epochs = 1000
|
||||||
loss = Flux.Losses.mse
|
loss = Flux.Losses.mse
|
||||||
opt = Flux.setup(Adam(0.0001), model)
|
opt = Flux.setup(Adam(0.0001), model)
|
||||||
dataloader = Flux.DataLoader(data |> cpu; batchsize=32, shuffle=true);
|
dataloader = Flux.DataLoader(data |> cpu; batchsize=32, shuffle=true);
|
||||||
|
@ -117,7 +117,7 @@ for epoch = 1:num_epochs
|
||||||
noisy_data = Diffusers.add_noise(scheduler, data, noise, timesteps)
|
noisy_data = Diffusers.add_noise(scheduler, data, noise, timesteps)
|
||||||
grads = Flux.gradient(model) do m
|
grads = Flux.gradient(model) do m
|
||||||
model_output = m(noisy_data, timesteps)
|
model_output = m(noisy_data, timesteps)
|
||||||
noise_prediction = Diffusers.step(scheduler, noisy_data, model_output, timesteps)
|
noise_prediction, _ = Diffusers.step(scheduler, noisy_data, model_output, timesteps)
|
||||||
loss(noise, noise_prediction)
|
loss(noise, noise_prediction)
|
||||||
end
|
end
|
||||||
Flux.update!(opt, params, grads)
|
Flux.update!(opt, params, grads)
|
||||||
|
@ -128,21 +128,40 @@ end
|
||||||
# sampling animation
|
# sampling animation
|
||||||
anim = @animate for timestep in num_timesteps:-1:2
|
anim = @animate for timestep in num_timesteps:-1:2
|
||||||
model_output = model(data, [timestep])
|
model_output = model(data, [timestep])
|
||||||
sampled_data = Diffusers.step(scheduler, data, model_output, [timestep])
|
sampled_data, x0_pred = Diffusers.step(scheduler, data, model_output, [timestep])
|
||||||
scatter(sampled_data[1, :], sampled_data[2, :],
|
|
||||||
|
p1 = scatter(sampled_data[1, :], sampled_data[2, :],
|
||||||
alpha=0.5,
|
alpha=0.5,
|
||||||
aspectratio=:equal,
|
aspectratio=:equal,
|
||||||
label="sampled data",
|
label="sampled data",
|
||||||
|
legend=false,
|
||||||
)
|
)
|
||||||
scatter!(data[1, :], data[2, :],
|
scatter!(data[1, :], data[2, :],
|
||||||
alpha=0.5,
|
alpha=0.5,
|
||||||
aspectratio=:equal,
|
aspectratio=:equal,
|
||||||
label="data",
|
label="data",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
p2 = scatter(x0_pred[1, :], x0_pred[2, :],
|
||||||
|
alpha=0.5,
|
||||||
|
aspectratio=:equal,
|
||||||
|
label="sampled data",
|
||||||
|
legend=false,
|
||||||
|
)
|
||||||
|
scatter!(data[1, :], data[2, :],
|
||||||
|
alpha=0.5,
|
||||||
|
aspectratio=:equal,
|
||||||
|
label="data",
|
||||||
|
)
|
||||||
|
|
||||||
|
l = @layout [a b]
|
||||||
i_str = lpad(timestep, 3, "0")
|
i_str = lpad(timestep, 3, "0")
|
||||||
title!("t = $(i_str)")
|
plot(p1, p2,
|
||||||
xlims!(-3, 3)
|
layout=l,
|
||||||
ylims!(-3, 3)
|
plot_title="t = $(i_str)",
|
||||||
|
)
|
||||||
|
xlims!(-2, 2)
|
||||||
|
ylims!(-2, 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
gif(anim, "sampling.gif", fps=30)
|
gif(anim, "sampling.gif", fps=30)
|
||||||
|
|
Loading…
Reference in a new issue