mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-09-18 18:45:28 +00:00
🚑 (swissroll) fix wrong training objective + wrong sampling 💀
This commit is contained in:
parent
42c2bcb5bb
commit
71ae55da71
|
@ -28,9 +28,8 @@ function normalize_neg_one_to_one(x)
|
|||
2 * normalize_zero_to_one(x) .- 1
|
||||
end
|
||||
|
||||
# make a dataset of 100 spirals
|
||||
n_points = 2500
|
||||
dataset = make_spiral(n_points, 1π, 5π)
|
||||
n_points = 1000
|
||||
dataset = make_spiral(n_points, 1.5π, 4.5π)
|
||||
dataset = normalize_neg_one_to_one(dataset)
|
||||
scatter(dataset[1, :], dataset[2, :],
|
||||
alpha=0.5,
|
||||
|
@ -39,16 +38,13 @@ scatter(dataset[1, :], dataset[2, :],
|
|||
|
||||
num_timesteps = 100
|
||||
scheduler = DDPM(
|
||||
Vector{Float64},
|
||||
rescale_zero_terminal_snr(
|
||||
cosine_beta_schedule(num_timesteps)
|
||||
)
|
||||
Vector{Float32},
|
||||
cosine_beta_schedule(num_timesteps)
|
||||
);
|
||||
|
||||
data = dataset[:, 1:100]
|
||||
noise = randn(size(data))
|
||||
|
||||
anim = @animate for t in cat(fill(0, 25), 1:num_timesteps, fill(num_timesteps, 50), dims=1)
|
||||
data = dataset
|
||||
noise = randn(Float32, size(data))
|
||||
anim = @animate for t in cat(fill(0, 2), 1:num_timesteps, fill(num_timesteps, 2), dims=1)
|
||||
if t == 0
|
||||
scatter(noise[1, :], noise[2, :],
|
||||
alpha=0.3,
|
||||
|
@ -65,7 +61,7 @@ anim = @animate for t in cat(fill(0, 25), 1:num_timesteps, fill(num_timesteps, 5
|
|||
aspectratio=:equal,
|
||||
label="noisy data",
|
||||
)
|
||||
title!("t = " * lpad(t, 3, "0"))
|
||||
title!(latexstring("t = " * lpad(t, 3, "0")))
|
||||
xlims!(-3, 3)
|
||||
ylims!(-3, 3)
|
||||
else
|
||||
|
@ -90,7 +86,7 @@ anim = @animate for t in cat(fill(0, 25), 1:num_timesteps, fill(num_timesteps, 5
|
|||
ylims!(-3, 3)
|
||||
end
|
||||
end
|
||||
gif(anim, anim.dir * ".gif", fps=50)
|
||||
gif(anim, anim.dir * ".gif", fps=2)
|
||||
|
||||
d_hid = 32
|
||||
model = ConditionalChain(
|
||||
|
@ -121,41 +117,39 @@ model = ConditionalChain(
|
|||
Dense(d_hid, 2),
|
||||
)
|
||||
|
||||
model(data, [100])
|
||||
model(data, [5])
|
||||
|
||||
num_epochs = 100;
|
||||
num_epochs = 10000;
|
||||
loss = Flux.Losses.mse;
|
||||
opt = Flux.setup(Adam(0.0001), model);
|
||||
dataloader = Flux.DataLoader(dataset |> cpu; batchsize=32, shuffle=true);
|
||||
opt = Flux.setup(Adam(0.001), model);
|
||||
dataloader = Flux.DataLoader(dataset; batchsize=32, shuffle=true);
|
||||
progress = Progress(num_epochs; desc="training", showspeed=true);
|
||||
for epoch = 1:num_epochs
|
||||
params = Flux.params(model)
|
||||
for data in dataloader
|
||||
noise = randn(size(data))
|
||||
timesteps = rand(2:num_timesteps, size(data, ndims(data))) # TODO: fix start at timestep=2, bruh
|
||||
noise = randn(Float32, size(data))
|
||||
timesteps = rand(1:num_timesteps, size(data, ndims(data)))
|
||||
noisy_data = Diffusers.Schedulers.add_noise(scheduler, data, noise, timesteps)
|
||||
grads = Flux.gradient(model) do m
|
||||
model_output = m(noisy_data, timesteps)
|
||||
noise_prediction, _ = Diffusers.Schedulers.step(scheduler, noisy_data, model_output, timesteps)
|
||||
loss(noise, noise_prediction)
|
||||
loss(noise, model_output)
|
||||
end
|
||||
Flux.update!(opt, params, grads)
|
||||
Flux.update!(opt, model, grads[1])
|
||||
end
|
||||
ProgressMeter.next!(progress)
|
||||
end
|
||||
|
||||
## sampling animation
|
||||
|
||||
sample = randn(2, 100)
|
||||
sample = randn(MersenneTwister(1), Float32, 2, 100)
|
||||
sample_old = sample
|
||||
predictions = []
|
||||
anim = for timestep in num_timesteps:-1:1
|
||||
model_output = model(data, [timestep])
|
||||
sample, x0_pred = Diffusers.Schedulers.step(scheduler, data, model_output, [timestep])
|
||||
for timestep in num_timesteps:-1:1
|
||||
model_output = model(sample, [timestep])
|
||||
sample, x0_pred = Diffusers.Schedulers.step(scheduler, sample, model_output, [timestep])
|
||||
push!(predictions, (sample, x0_pred, timestep))
|
||||
end
|
||||
|
||||
anim = @animate for i in cat(fill(0, 50), 1:num_timesteps, fill(num_timesteps, 50), dims=1)
|
||||
anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20), dims=1)
|
||||
if i == 0
|
||||
p1 = scatter(dataset[1, :], dataset[2, :],
|
||||
alpha=0.01,
|
||||
|
@ -208,4 +202,4 @@ anim = @animate for i in cat(fill(0, 50), 1:num_timesteps, fill(num_timesteps, 5
|
|||
ylims!(-2, 2)
|
||||
end
|
||||
end
|
||||
gif(anim, anim.dir * ".gif", fps=50)
|
||||
gif(anim, anim.dir * ".gif", fps=20)
|
||||
|
|
|
@ -121,6 +121,7 @@ function step(
|
|||
# arxiv:2006.11239 Eq. 6
|
||||
# arxiv:2208.11970 Eq. 70
|
||||
σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ # TODO: this could be stored in the scheduler
|
||||
σₜ = exp.(log.(σₜ) ./ 2) # https://github.com/huggingface/diffusers/blob/160474ac61934cc22793d6cebea118c171175dbc/src/diffusers/schedulers/scheduling_ddpm.py#L306
|
||||
xₜ₋₁ = μ̃ₜ + σₜ .* randn(size(ϵᵧ))
|
||||
|
||||
return xₜ₋₁, x̂₀
|
||||
|
|
Loading…
Reference in a new issue