🎉 yeah boi we diffusin'

This commit is contained in:
Laureηt 2023-07-23 14:48:15 +02:00
parent a71c40dc6d
commit a564a8f6f6
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
2 changed files with 52 additions and 24 deletions

View file

@ -2,6 +2,7 @@ import Diffusers
using Flux using Flux
using Random using Random
using Plots using Plots
using ProgressMeter
function make_spiral(rng::AbstractRNG, n_samples::Int=1000) function make_spiral(rng::AbstractRNG, n_samples::Int=1000)
t_min = 1.5π t_min = 1.5π
@ -42,7 +43,7 @@ scheduler = Diffusers.DDPM(
noise = randn(size(data)) noise = randn(size(data))
anim = @animate for i in cat(collect(1:num_timesteps), repeat([num_timesteps], 50), dims=1) anim = @animate for i in cat(1:num_timesteps, repeat([num_timesteps], 50), dims=1)
noisy_data = Diffusers.add_noise(scheduler, data, noise, [i]) noisy_data = Diffusers.add_noise(scheduler, data, noise, [i])
scatter(noise[1, :], noise[2, :], scatter(noise[1, :], noise[2, :],
alpha=0.1, alpha=0.1,
@ -99,21 +100,45 @@ model = Diffusers.ConditionalChain(
model(data, [100]) model(data, [100])
num_epochs = 100
num_epochs = 10
loss = Flux.Losses.mse loss = Flux.Losses.mse
dataloader = Flux.DataLoader(X |> to_device; batchsize=32, shuffle=true); opt = Flux.setup(Adam(0.0001), model)
dataloader = Flux.DataLoader(data |> cpu; batchsize=32, shuffle=true);
progress = Progress(num_epochs; desc="training", showspeed=true)
for epoch = 1:num_epochs for epoch = 1:num_epochs
progress = Progress(length(data); desc="epoch $epoch/$num_epochs")
params = Flux.params(model) params = Flux.params(model)
for data in dataloader for data in dataloader
noise = randn(size(data))
timesteps = rand(2:num_timesteps, size(data)[2]) # TODO: fix start at timestep=2, bruh
noisy_data = Diffusers.add_noise(scheduler, data, noise, timesteps)
grads = Flux.gradient(model) do m grads = Flux.gradient(model) do m
model_output = m(data) model_output = m(noisy_data, timesteps)
noise_prediction = Diffusers.step(model_output, timesteps, scheduler) noise_prediction = Diffusers.step(scheduler, noisy_data, model_output, timesteps)
loss(noise, noise_prediction) loss(noise, noise_prediction)
end end
Flux.update!(opt, params, grads) Flux.update!(opt, params, grads)
ProgressMeter.next!(progress; showvalues=[("batch loss", @sprintf("%.5f", batch_loss))])
end end
ProgressMeter.next!(progress)
end end
# sampling animation
anim = @animate for timestep in num_timesteps:-1:2
model_output = model(data, [timestep])
sampled_data = Diffusers.step(scheduler, data, model_output, [timestep])
scatter(sampled_data[1, :], sampled_data[2, :],
alpha=0.5,
aspectratio=:equal,
label="sampled data",
)
scatter!(data[1, :], data[2, :],
alpha=0.5,
aspectratio=:equal,
label="data",
)
i_str = lpad(timestep, 3, "0")
title!("t = $(i_str)")
xlims!(-3, 3)
ylims!(-3, 3)
end
gif(anim, "sampling.gif", fps=30)

View file

@ -21,14 +21,16 @@ function add_noise(
sqrt_α_cumprod_t = scheduler.sqrt_α_cumprods[timesteps] sqrt_α_cumprod_t = scheduler.sqrt_α_cumprods[timesteps]
sqrt_one_minus_α_cumprod_t = scheduler.sqrt_one_minus_α_cumprods[timesteps] sqrt_one_minus_α_cumprod_t = scheduler.sqrt_one_minus_α_cumprods[timesteps]
sqrt_α_cumprod_t .* clean_data .+ sqrt_one_minus_α_cumprod_t .* noise noisy_data = sqrt_α_cumprod_t' .* clean_data + sqrt_one_minus_α_cumprod_t' .* noise
return noisy_data
end end
function step( function step(
scheduler::Scheduler, scheduler::Scheduler,
sample::AbstractArray, sample::AbstractArray,
model_output::AbstractArray, model_output::AbstractArray,
timestep::Int, timesteps::AbstractArray,
) )
""" """
Remove noise from model output using the backward diffusion process. Remove noise from model output using the backward diffusion process.
@ -37,36 +39,37 @@ function step(
scheduler (`Scheduler`): scheduler object. scheduler (`Scheduler`): scheduler object.
sample (`AbstractArray`): sample to remove noise from, i.e. model_input. sample (`AbstractArray`): sample to remove noise from, i.e. model_input.
model_output (`AbstractArray`): predicted noise from the model. model_output (`AbstractArray`): predicted noise from the model.
timestep (`Int`): timestep to remove noise from. timesteps (`AbstractArray`): timesteps to remove noise from.
Returns: Returns:
`AbstractArray`: denoised model output at the given timestep. `AbstractArray`: denoised model output at the given timestep.
""" """
# 1. compute alphas, betas # 1. compute alphas, betas
α_cumprod_t = scheduler.α_cumprods[timestep] α_cumprod_t = scheduler.α_cumprods[timesteps]
α_cumprod_t_prev = scheduler.α_cumprods[timestep - 1] α_cumprod_t_prev = scheduler.α_cumprods[timesteps.-1]
β_cumprod_t = 1 - α_cumprod_t β_cumprod_t = 1 .- α_cumprod_t
β_cumprod_t_prev = 1 - α_cumprod_t_prev β_cumprod_t_prev = 1 .- α_cumprod_t_prev
current_α_t = α_cumprod_t / α_cumprod_t_prev current_α_t = α_cumprod_t ./ α_cumprod_t_prev
current_β_t = 1 - current_α_t current_β_t = 1 .- current_α_t
# 2. compute predicted original sample from predicted noise also called # 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
# epsilon prediction type # epsilon prediction type
x_0 = (noise - β_cumprod_t_prev * model_output) / α_cumprod_t_prev # print shapes of thingies
x_0_pred = (sample - sqrt.(β_cumprod_t)' .* model_output) ./ sqrt.(α_cumprod_t)'
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_sample_coeff = (α_cumprod_t_prev * current_β_t) / β_cumprod_t pred_original_sample_coeff = (sqrt.(α_cumprod_t_prev) .* current_β_t) ./ β_cumprod_t
current_sample_coeff = current_α_t * β_cumprod_t_prev / β_cumprod_t current_sample_coeff = sqrt.(current_α_t) .* β_cumprod_t_prev ./ β_cumprod_t
# 5. Compute predicted previous sample µ_t # 5. Compute predicted previous sample µ_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample pred_prev_sample = pred_original_sample_coeff' .* x_0_pred + current_sample_coeff' .* sample
# 6. Add noise # 6. Add noise
variance = scheduler.βs[timestep] * randn(size(model_output)) variance = sqrt.(scheduler.βs[timesteps])' .* randn(size(model_output))
pred_prev_sample = pred_prev_sample + variance pred_prev_sample = pred_prev_sample + variance
return pred_prev_sample, pred_original_sample return pred_prev_sample
end