mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-11-09 15:02:02 +00:00
🎉 yeah boi we diffusin'
This commit is contained in:
parent
a71c40dc6d
commit
a564a8f6f6
|
@ -2,6 +2,7 @@ import Diffusers
|
||||||
using Flux
|
using Flux
|
||||||
using Random
|
using Random
|
||||||
using Plots
|
using Plots
|
||||||
|
using ProgressMeter
|
||||||
|
|
||||||
function make_spiral(rng::AbstractRNG, n_samples::Int=1000)
|
function make_spiral(rng::AbstractRNG, n_samples::Int=1000)
|
||||||
t_min = 1.5π
|
t_min = 1.5π
|
||||||
|
@ -42,7 +43,7 @@ scheduler = Diffusers.DDPM(
|
||||||
|
|
||||||
noise = randn(size(data))
|
noise = randn(size(data))
|
||||||
|
|
||||||
anim = @animate for i in cat(collect(1:num_timesteps), repeat([num_timesteps], 50), dims=1)
|
anim = @animate for i in cat(1:num_timesteps, repeat([num_timesteps], 50), dims=1)
|
||||||
noisy_data = Diffusers.add_noise(scheduler, data, noise, [i])
|
noisy_data = Diffusers.add_noise(scheduler, data, noise, [i])
|
||||||
scatter(noise[1, :], noise[2, :],
|
scatter(noise[1, :], noise[2, :],
|
||||||
alpha=0.1,
|
alpha=0.1,
|
||||||
|
@ -99,21 +100,45 @@ model = Diffusers.ConditionalChain(
|
||||||
|
|
||||||
model(data, [100])
|
model(data, [100])
|
||||||
|
|
||||||
|
num_epochs = 100
|
||||||
num_epochs = 10
|
|
||||||
loss = Flux.Losses.mse
|
loss = Flux.Losses.mse
|
||||||
dataloader = Flux.DataLoader(X |> to_device; batchsize=32, shuffle=true);
|
opt = Flux.setup(Adam(0.0001), model)
|
||||||
|
dataloader = Flux.DataLoader(data |> cpu; batchsize=32, shuffle=true);
|
||||||
|
progress = Progress(num_epochs; desc="training", showspeed=true)
|
||||||
for epoch = 1:num_epochs
|
for epoch = 1:num_epochs
|
||||||
progress = Progress(length(data); desc="epoch $epoch/$num_epochs")
|
|
||||||
params = Flux.params(model)
|
params = Flux.params(model)
|
||||||
for data in dataloader
|
for data in dataloader
|
||||||
|
noise = randn(size(data))
|
||||||
|
timesteps = rand(2:num_timesteps, size(data)[2]) # TODO: fix start at timestep=2, bruh
|
||||||
|
noisy_data = Diffusers.add_noise(scheduler, data, noise, timesteps)
|
||||||
grads = Flux.gradient(model) do m
|
grads = Flux.gradient(model) do m
|
||||||
model_output = m(data)
|
model_output = m(noisy_data, timesteps)
|
||||||
noise_prediction = Diffusers.step(model_output, timesteps, scheduler)
|
noise_prediction = Diffusers.step(scheduler, noisy_data, model_output, timesteps)
|
||||||
loss(noise, noise_prediction)
|
loss(noise, noise_prediction)
|
||||||
end
|
end
|
||||||
Flux.update!(opt, params, grads)
|
Flux.update!(opt, params, grads)
|
||||||
ProgressMeter.next!(progress; showvalues=[("batch loss", @sprintf("%.5f", batch_loss))])
|
|
||||||
end
|
end
|
||||||
|
ProgressMeter.next!(progress)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# sampling animation
|
||||||
|
anim = @animate for timestep in num_timesteps:-1:2
|
||||||
|
model_output = model(data, [timestep])
|
||||||
|
sampled_data = Diffusers.step(scheduler, data, model_output, [timestep])
|
||||||
|
scatter(sampled_data[1, :], sampled_data[2, :],
|
||||||
|
alpha=0.5,
|
||||||
|
aspectratio=:equal,
|
||||||
|
label="sampled data",
|
||||||
|
)
|
||||||
|
scatter!(data[1, :], data[2, :],
|
||||||
|
alpha=0.5,
|
||||||
|
aspectratio=:equal,
|
||||||
|
label="data",
|
||||||
|
)
|
||||||
|
i_str = lpad(timestep, 3, "0")
|
||||||
|
title!("t = $(i_str)")
|
||||||
|
xlims!(-3, 3)
|
||||||
|
ylims!(-3, 3)
|
||||||
|
end
|
||||||
|
|
||||||
|
gif(anim, "sampling.gif", fps=30)
|
||||||
|
|
|
@ -21,14 +21,16 @@ function add_noise(
|
||||||
sqrt_α_cumprod_t = scheduler.sqrt_α_cumprods[timesteps]
|
sqrt_α_cumprod_t = scheduler.sqrt_α_cumprods[timesteps]
|
||||||
sqrt_one_minus_α_cumprod_t = scheduler.sqrt_one_minus_α_cumprods[timesteps]
|
sqrt_one_minus_α_cumprod_t = scheduler.sqrt_one_minus_α_cumprods[timesteps]
|
||||||
|
|
||||||
sqrt_α_cumprod_t .* clean_data .+ sqrt_one_minus_α_cumprod_t .* noise
|
noisy_data = sqrt_α_cumprod_t' .* clean_data + sqrt_one_minus_α_cumprod_t' .* noise
|
||||||
|
|
||||||
|
return noisy_data
|
||||||
end
|
end
|
||||||
|
|
||||||
function step(
|
function step(
|
||||||
scheduler::Scheduler,
|
scheduler::Scheduler,
|
||||||
sample::AbstractArray,
|
sample::AbstractArray,
|
||||||
model_output::AbstractArray,
|
model_output::AbstractArray,
|
||||||
timestep::Int,
|
timesteps::AbstractArray,
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
Remove noise from model output using the backward diffusion process.
|
Remove noise from model output using the backward diffusion process.
|
||||||
|
@ -37,36 +39,37 @@ function step(
|
||||||
scheduler (`Scheduler`): scheduler object.
|
scheduler (`Scheduler`): scheduler object.
|
||||||
sample (`AbstractArray`): sample to remove noise from, i.e. model_input.
|
sample (`AbstractArray`): sample to remove noise from, i.e. model_input.
|
||||||
model_output (`AbstractArray`): predicted noise from the model.
|
model_output (`AbstractArray`): predicted noise from the model.
|
||||||
timestep (`Int`): timestep to remove noise from.
|
timesteps (`AbstractArray`): timesteps to remove noise from.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`AbstractArray`: denoised model output at the given timestep.
|
`AbstractArray`: denoised model output at the given timestep.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 1. compute alphas, betas
|
# 1. compute alphas, betas
|
||||||
α_cumprod_t = scheduler.α_cumprods[timestep]
|
α_cumprod_t = scheduler.α_cumprods[timesteps]
|
||||||
α_cumprod_t_prev = scheduler.α_cumprods[timestep - 1]
|
α_cumprod_t_prev = scheduler.α_cumprods[timesteps.-1]
|
||||||
β_cumprod_t = 1 - α_cumprod_t
|
β_cumprod_t = 1 .- α_cumprod_t
|
||||||
β_cumprod_t_prev = 1 - α_cumprod_t_prev
|
β_cumprod_t_prev = 1 .- α_cumprod_t_prev
|
||||||
current_α_t = α_cumprod_t / α_cumprod_t_prev
|
current_α_t = α_cumprod_t ./ α_cumprod_t_prev
|
||||||
current_β_t = 1 - current_α_t
|
current_β_t = 1 .- current_α_t
|
||||||
|
|
||||||
# 2. compute predicted original sample from predicted noise also called
|
# 2. compute predicted original sample from predicted noise also called
|
||||||
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
|
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
|
||||||
# epsilon prediction type
|
# epsilon prediction type
|
||||||
x_0 = (noise - √β_cumprod_t_prev * model_output) / √α_cumprod_t_prev
|
# print shapes of thingies
|
||||||
|
x_0_pred = (sample - sqrt.(β_cumprod_t)' .* model_output) ./ sqrt.(α_cumprod_t)'
|
||||||
|
|
||||||
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
|
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
|
||||||
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
||||||
pred_original_sample_coeff = (√α_cumprod_t_prev * current_β_t) / β_cumprod_t
|
pred_original_sample_coeff = (sqrt.(α_cumprod_t_prev) .* current_β_t) ./ β_cumprod_t
|
||||||
current_sample_coeff = √current_α_t * β_cumprod_t_prev / β_cumprod_t
|
current_sample_coeff = sqrt.(current_α_t) .* β_cumprod_t_prev ./ β_cumprod_t
|
||||||
|
|
||||||
# 5. Compute predicted previous sample µ_t
|
# 5. Compute predicted previous sample µ_t
|
||||||
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
||||||
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
|
pred_prev_sample = pred_original_sample_coeff' .* x_0_pred + current_sample_coeff' .* sample
|
||||||
|
|
||||||
# 6. Add noise
|
# 6. Add noise
|
||||||
variance = √scheduler.βs[timestep] * randn(size(model_output))
|
variance = sqrt.(scheduler.βs[timesteps])' .* randn(size(model_output))
|
||||||
pred_prev_sample = pred_prev_sample + variance
|
pred_prev_sample = pred_prev_sample + variance
|
||||||
|
|
||||||
return pred_prev_sample, pred_original_sample
|
return pred_prev_sample
|
||||||
|
end
|
||||||
|
|
Loading…
Reference in a new issue