From 71ae55da7194c371f65fbd772b3c226a9274b6dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laure=CE=B7t?= Date: Tue, 1 Aug 2023 22:59:12 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=91=20(swissroll)=20fix=20wrong=20trai?= =?UTF-8?q?ning=20objective=20+=20wrong=20sampling=20=F0=9F=92=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/swissroll.jl | 52 +++++++++++++++++++----------------------- src/Schedulers/DDPM.jl | 1 + 2 files changed, 24 insertions(+), 29 deletions(-) diff --git a/examples/swissroll.jl b/examples/swissroll.jl index 1122949..b8802ac 100644 --- a/examples/swissroll.jl +++ b/examples/swissroll.jl @@ -28,9 +28,8 @@ function normalize_neg_one_to_one(x) 2 * normalize_zero_to_one(x) .- 1 end -# make a dataset of 100 spirals -n_points = 2500 -dataset = make_spiral(n_points, 1π, 5π) +n_points = 1000 +dataset = make_spiral(n_points, 1.5π, 4.5π) dataset = normalize_neg_one_to_one(dataset) scatter(dataset[1, :], dataset[2, :], alpha=0.5, @@ -39,16 +38,13 @@ scatter(dataset[1, :], dataset[2, :], num_timesteps = 100 scheduler = DDPM( - Vector{Float64}, - rescale_zero_terminal_snr( - cosine_beta_schedule(num_timesteps) - ) + Vector{Float32}, + cosine_beta_schedule(num_timesteps) ); -data = dataset[:, 1:100] -noise = randn(size(data)) - -anim = @animate for t in cat(fill(0, 25), 1:num_timesteps, fill(num_timesteps, 50), dims=1) +data = dataset +noise = randn(Float32, size(data)) +anim = @animate for t in cat(fill(0, 2), 1:num_timesteps, fill(num_timesteps, 2), dims=1) if t == 0 scatter(noise[1, :], noise[2, :], alpha=0.3, @@ -65,7 +61,7 @@ anim = @animate for t in cat(fill(0, 25), 1:num_timesteps, fill(num_timesteps, 5 aspectratio=:equal, label="noisy data", ) - title!("t = " * lpad(t, 3, "0")) + title!(latexstring("t = " * lpad(t, 3, "0"))) xlims!(-3, 3) ylims!(-3, 3) else @@ -90,7 +86,7 @@ anim = @animate for t in cat(fill(0, 25), 1:num_timesteps, fill(num_timesteps, 5 ylims!(-3, 3) end end -gif(anim, anim.dir * ".gif", fps=50) +gif(anim, anim.dir * ".gif", fps=2) d_hid = 32 model = ConditionalChain( @@ -121,41 +117,39 @@ model = ConditionalChain( Dense(d_hid, 2), ) -model(data, [100]) +model(data, [5]) -num_epochs = 100; +num_epochs = 10000; loss = Flux.Losses.mse; -opt = Flux.setup(Adam(0.0001), model); -dataloader = Flux.DataLoader(dataset |> cpu; batchsize=32, shuffle=true); +opt = Flux.setup(Adam(0.001), model); +dataloader = Flux.DataLoader(dataset; batchsize=32, shuffle=true); progress = Progress(num_epochs; desc="training", showspeed=true); for epoch = 1:num_epochs - params = Flux.params(model) for data in dataloader - noise = randn(size(data)) - timesteps = rand(2:num_timesteps, size(data, ndims(data))) # TODO: fix start at timestep=2, bruh + noise = randn(Float32, size(data)) + timesteps = rand(1:num_timesteps, size(data, ndims(data))) noisy_data = Diffusers.Schedulers.add_noise(scheduler, data, noise, timesteps) grads = Flux.gradient(model) do m model_output = m(noisy_data, timesteps) - noise_prediction, _ = Diffusers.Schedulers.step(scheduler, noisy_data, model_output, timesteps) - loss(noise, noise_prediction) + loss(noise, model_output) end - Flux.update!(opt, params, grads) + Flux.update!(opt, model, grads[1]) end ProgressMeter.next!(progress) end ## sampling animation -sample = randn(2, 100) +sample = randn(MersenneTwister(1), Float32, 2, 100) sample_old = sample predictions = [] -anim = for timestep in num_timesteps:-1:1 - model_output = model(data, [timestep]) - sample, x0_pred = Diffusers.Schedulers.step(scheduler, data, model_output, [timestep]) +for timestep in num_timesteps:-1:1 + model_output = model(sample, [timestep]) + sample, x0_pred = Diffusers.Schedulers.step(scheduler, sample, model_output, [timestep]) push!(predictions, (sample, x0_pred, timestep)) end -anim = @animate for i in cat(fill(0, 50), 1:num_timesteps, fill(num_timesteps, 50), dims=1) +anim = @animate for i in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20), dims=1) if i == 0 p1 = scatter(dataset[1, :], dataset[2, :], alpha=0.01, @@ -208,4 +202,4 @@ anim = @animate for i in cat(fill(0, 50), 1:num_timesteps, fill(num_timesteps, 5 ylims!(-2, 2) end end -gif(anim, anim.dir * ".gif", fps=50) +gif(anim, anim.dir * ".gif", fps=20) diff --git a/src/Schedulers/DDPM.jl b/src/Schedulers/DDPM.jl index 846e6f1..0bd7766 100644 --- a/src/Schedulers/DDPM.jl +++ b/src/Schedulers/DDPM.jl @@ -121,6 +121,7 @@ function step( # arxiv:2006.11239 Eq. 6 # arxiv:2208.11970 Eq. 70 σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ # TODO: this could be stored in the scheduler + σₜ = exp.(log.(σₜ) ./ 2) # https://github.com/huggingface/diffusers/blob/160474ac61934cc22793d6cebea118c171175dbc/src/diffusers/schedulers/scheduling_ddpm.py#L306 xₜ₋₁ = μ̃ₜ + σₜ .* randn(size(ϵᵧ)) return xₜ₋₁, x̂₀