From a564a8f6f693bf135e4fe4196f15d292976c4fa5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laure=CE=B7t?= Date: Sun, 23 Jul 2023 14:48:15 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=89=20yeah=20boi=20we=20diffusin'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/swissroll.jl | 41 +++++++++++++++++++++++++++++++++-------- src/Schedulers.jl | 35 +++++++++++++++++++---------------- 2 files changed, 52 insertions(+), 24 deletions(-) diff --git a/examples/swissroll.jl b/examples/swissroll.jl index a5f9345..ea29357 100644 --- a/examples/swissroll.jl +++ b/examples/swissroll.jl @@ -2,6 +2,7 @@ import Diffusers using Flux using Random using Plots +using ProgressMeter function make_spiral(rng::AbstractRNG, n_samples::Int=1000) t_min = 1.5π @@ -42,7 +43,7 @@ scheduler = Diffusers.DDPM( noise = randn(size(data)) -anim = @animate for i in cat(collect(1:num_timesteps), repeat([num_timesteps], 50), dims=1) +anim = @animate for i in cat(1:num_timesteps, repeat([num_timesteps], 50), dims=1) noisy_data = Diffusers.add_noise(scheduler, data, noise, [i]) scatter(noise[1, :], noise[2, :], alpha=0.1, @@ -99,21 +100,45 @@ model = Diffusers.ConditionalChain( model(data, [100]) - -num_epochs = 10 +num_epochs = 100 loss = Flux.Losses.mse -dataloader = Flux.DataLoader(X |> to_device; batchsize=32, shuffle=true); +opt = Flux.setup(Adam(0.0001), model) +dataloader = Flux.DataLoader(data |> cpu; batchsize=32, shuffle=true); +progress = Progress(num_epochs; desc="training", showspeed=true) for epoch = 1:num_epochs - progress = Progress(length(data); desc="epoch $epoch/$num_epochs") params = Flux.params(model) for data in dataloader + noise = randn(size(data)) + timesteps = rand(2:num_timesteps, size(data)[2]) # TODO: fix start at timestep=2, bruh + noisy_data = Diffusers.add_noise(scheduler, data, noise, timesteps) grads = Flux.gradient(model) do m - model_output = m(data) - noise_prediction = Diffusers.step(model_output, timesteps, scheduler) + model_output = m(noisy_data, timesteps) + noise_prediction = Diffusers.step(scheduler, noisy_data, model_output, timesteps) loss(noise, noise_prediction) end Flux.update!(opt, params, grads) - ProgressMeter.next!(progress; showvalues=[("batch loss", @sprintf("%.5f", batch_loss))]) end + ProgressMeter.next!(progress) end +# sampling animation +anim = @animate for timestep in num_timesteps:-1:2 + model_output = model(data, [timestep]) + sampled_data = Diffusers.step(scheduler, data, model_output, [timestep]) + scatter(sampled_data[1, :], sampled_data[2, :], + alpha=0.5, + aspectratio=:equal, + label="sampled data", + ) + scatter!(data[1, :], data[2, :], + alpha=0.5, + aspectratio=:equal, + label="data", + ) + i_str = lpad(timestep, 3, "0") + title!("t = $(i_str)") + xlims!(-3, 3) + ylims!(-3, 3) +end + +gif(anim, "sampling.gif", fps=30) diff --git a/src/Schedulers.jl b/src/Schedulers.jl index 9f71c6f..d3ecd99 100644 --- a/src/Schedulers.jl +++ b/src/Schedulers.jl @@ -21,14 +21,16 @@ function add_noise( sqrt_α_cumprod_t = scheduler.sqrt_α_cumprods[timesteps] sqrt_one_minus_α_cumprod_t = scheduler.sqrt_one_minus_α_cumprods[timesteps] - sqrt_α_cumprod_t .* clean_data .+ sqrt_one_minus_α_cumprod_t .* noise + noisy_data = sqrt_α_cumprod_t' .* clean_data + sqrt_one_minus_α_cumprod_t' .* noise + + return noisy_data end function step( scheduler::Scheduler, sample::AbstractArray, model_output::AbstractArray, - timestep::Int, + timesteps::AbstractArray, ) """ Remove noise from model output using the backward diffusion process. @@ -37,36 +39,37 @@ function step( scheduler (`Scheduler`): scheduler object. sample (`AbstractArray`): sample to remove noise from, i.e. model_input. model_output (`AbstractArray`): predicted noise from the model. - timestep (`Int`): timestep to remove noise from. + timesteps (`AbstractArray`): timesteps to remove noise from. Returns: `AbstractArray`: denoised model output at the given timestep. """ - # 1. compute alphas, betas - α_cumprod_t = scheduler.α_cumprods[timestep] - α_cumprod_t_prev = scheduler.α_cumprods[timestep - 1] - β_cumprod_t = 1 - α_cumprod_t - β_cumprod_t_prev = 1 - α_cumprod_t_prev - current_α_t = α_cumprod_t / α_cumprod_t_prev - current_β_t = 1 - current_α_t + α_cumprod_t = scheduler.α_cumprods[timesteps] + α_cumprod_t_prev = scheduler.α_cumprods[timesteps.-1] + β_cumprod_t = 1 .- α_cumprod_t + β_cumprod_t_prev = 1 .- α_cumprod_t_prev + current_α_t = α_cumprod_t ./ α_cumprod_t_prev + current_β_t = 1 .- current_α_t # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf # epsilon prediction type - x_0 = (noise - √β_cumprod_t_prev * model_output) / √α_cumprod_t_prev + # print shapes of thingies + x_0_pred = (sample - sqrt.(β_cumprod_t)' .* model_output) ./ sqrt.(α_cumprod_t)' # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - pred_original_sample_coeff = (√α_cumprod_t_prev * current_β_t) / β_cumprod_t - current_sample_coeff = √current_α_t * β_cumprod_t_prev / β_cumprod_t + pred_original_sample_coeff = (sqrt.(α_cumprod_t_prev) .* current_β_t) ./ β_cumprod_t + current_sample_coeff = sqrt.(current_α_t) .* β_cumprod_t_prev ./ β_cumprod_t # 5. Compute predicted previous sample µ_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + pred_prev_sample = pred_original_sample_coeff' .* x_0_pred + current_sample_coeff' .* sample # 6. Add noise - variance = √scheduler.βs[timestep] * randn(size(model_output)) + variance = sqrt.(scheduler.βs[timesteps])' .* randn(size(model_output)) pred_prev_sample = pred_prev_sample + variance - return pred_prev_sample, pred_original_sample + return pred_prev_sample +end