mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-11-08 22:42:01 +00:00
180 lines
5.4 KiB
Julia
180 lines
5.4 KiB
Julia
|
# This example compares the different beta schedules available in Diffusers.jl.
|
||
|
# Code related to the generation of the datasets and the plots is hidden.
|
||
|
|
||
|
using Diffusers.Schedulers: DDPM, forward # hide
|
||
|
using Diffusers.BetaSchedules # hide
|
||
|
using ProgressMeter # hide
|
||
|
using LaTeXStrings # hide
|
||
|
using Random # hide
|
||
|
using Plots # hide
|
||
|
using Flux # hide
|
||
|
using MLDatasets # hide
|
||
|
|
||
|
## Linearly rescale `x` so that its minimum maps to 0 and its maximum maps to 1. # hide
## A constant input (maximum == minimum) maps to all zeros instead of NaN. # hide
function normalize_zero_to_one(x) # hide
    x_min, x_max = extrema(x) # hide
    span = x_max - x_min # hide
    ## Dividing by a zero span would yield 0/0 = NaN for every element. # hide
    scale = iszero(span) ? oneunit(span) : span # hide
    return (x .- x_min) ./ scale # hide
end # hide
|
||
|
|
||
|
## Rescale `x` to [-1, 1]: normalise to [0, 1] first, then double and shift down by 1. # hide
function normalize_neg_one_to_one(x) # hide
    unit_range = normalize_zero_to_one(x) # hide
    return 2 * unit_range .- 1 # hide
end # hide
|
||
|
|
||
|
## Number of timesteps handed to every beta schedule. # hide
num_timesteps = 100 # hide
## Beta schedules under comparison; this order is also the order of the plot panels. # hide
beta_schedules = [ # hide
    linear_beta_schedule, # hide
    scaled_linear_beta_schedule, # hide
    cosine_beta_schedule, # hide
    sigmoid_beta_schedule, # hide
    exponential_beta_schedule, # hide
] # hide
## One DDPM scheduler per beta schedule, built from the collected schedule values. # hide
schedulers = map(beta_schedules) do schedule # hide
    DDPM(collect(schedule(num_timesteps))) # hide
end; # hide
|
||
|
|
||
|
# ## Swiss Roll
|
||
|
|
||
|
## Sample `n_samples` points on the spiral (t * cos(t), t * sin(t)), with `t` drawn # hide
## uniformly between `t_min` and `t_max`. Returns a 2 x n_samples matrix whose first # hide
## row holds the x coordinates and whose second row holds the y coordinates. # hide
function make_spiral(n_samples::Integer=1000, t_min::Real=1.5π, t_max::Real=4.5π) # hide
    ## Sample in the floating-point type common to both bounds, so that integer or # hide
    ## irrational bounds (e.g. `1`, `π`) still yield uniform draws in [t_min, t_max]. # hide
    T = float(promote_type(typeof(t_min), typeof(t_max))) # hide
    lo, hi = T(t_min), T(t_max) # hide
    t = rand(T, n_samples) .* (hi - lo) .+ lo # hide

    x = t .* cos.(t) # hide
    y = t .* sin.(t) # hide

    return permutedims([x y], (2, 1)) # hide
end # hide
|
||
|
|
||
|
## Number of points sampled for the 2-D datasets. # hide
n_points = 1000; # hide
## Float32 spiral between 1.5π and 4.5π, rescaled so all values lie in [-1, 1]. # hide
dataset = normalize_neg_one_to_one(make_spiral(n_points, 1.5f0 * π, 4.5f0 * π)); # hide
|
||
|
|
||
|
## Draw the Gaussian noise once, outside the loop, so every frame reuses the same sample. # hide
noise = randn(Float32, size(dataset)) # hide
## Frames: 20 at t = 0, then t = 1 up to num_timesteps, then 20 at the final timestep. # hide
anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20), dims=1) # hide
    ## One panel per scheduler, in the order of `beta_schedules`. # hide
    panels = map(zip(schedulers, beta_schedules)) do (scheduler, beta_schedule) # hide
        ## `forward` is only called for t >= 1; at t = 0 the dataset is drawn as-is. # hide
        noisy_data = t == 0 ? dataset : forward(scheduler, dataset, noise, [t]) # hide
        ## First layer: the clean dataset, drawn in every frame. # hide
        scatter(dataset[1, :], dataset[2, :], # hide
            alpha=0.5, # hide
            aspectratio=:equal, # hide
            legend=false, # hide
        ) # hide
        ## Second layer: the points at timestep `t`. The local is deliberately not # hide
        ## called `plot`, so it cannot shadow `Plots.plot` used below. # hide
        panel = scatter!(noisy_data[1, :], noisy_data[2, :], # hide
            alpha=0.5, # hide
            aspectratio=:equal, # hide
        ) # hide
        title!(string(beta_schedule)) # hide
        xlims!(-3, 3) # hide
        ylims!(-3, 3) # hide
        panel # hide
    end # hide
    plot(panels...; size=(1200, 800)) # hide
end # hide

gif(anim, anim.dir * ".gif", fps=20) # hide
|
||
|
|
||
|
# ## Double Square
|
||
|
|
||
|
## Sample `n_samples` points uniformly in [-1, 1) x [-1, 1) and divide each point by # hide
## its largest absolute coordinate, which places it on the boundary of the square. # hide
## Returns a 2 x n_samples matrix (row 1: x coordinates, row 2: y coordinates). # hide
function make_square(n_samples::Integer=1000) # hide
    xs = 2 .* rand(n_samples) .- 1 # hide
    ys = 2 .* rand(n_samples) .- 1 # hide
    points = permutedims(hcat(xs, ys)) # hide
    ## Per-column largest absolute coordinate, shape 1 x n_samples. # hide
    largest = maximum(abs, points; dims=1) # hide
    return points ./ largest # hide
end # hide
|
||
|
|
||
|
## Two `make_square` samples, each scaled by 1/2 and shifted by -1.5 / +1.5 on both # hide
## coordinates, concatenated column-wise. Integer division keeps this working for odd # hide
## `n_points`; the second square takes the remainder so the total stays `n_points`. # hide
n_first_square = n_points ÷ 2 # hide
dataset = hcat( # hide
    make_square(n_first_square) ./ 2 .- 1.5, # hide
    make_square(n_points - n_first_square) ./ 2 .+ 1.5 # hide
) # hide
|
||
|
|
||
|
## Draw the Gaussian noise once, outside the loop, so every frame reuses the same sample. # hide
noise = randn(Float32, size(dataset)) # hide
## Frames: 20 at t = 0, then t = 1 up to num_timesteps, then 20 at the final timestep. # hide
anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20), dims=1) # hide
    ## One panel per scheduler, in the order of `beta_schedules`. # hide
    panels = map(zip(schedulers, beta_schedules)) do (scheduler, beta_schedule) # hide
        ## `forward` is only called for t >= 1; at t = 0 the dataset is drawn as-is. # hide
        noisy_data = t == 0 ? dataset : forward(scheduler, dataset, noise, [t]) # hide
        ## First layer: the clean dataset, drawn in every frame. # hide
        scatter(dataset[1, :], dataset[2, :], # hide
            alpha=0.5, # hide
            aspectratio=:equal, # hide
            legend=false, # hide
        ) # hide
        ## Second layer: the points at timestep `t`. The local is deliberately not # hide
        ## called `plot`, so it cannot shadow `Plots.plot` used below. # hide
        panel = scatter!(noisy_data[1, :], noisy_data[2, :], # hide
            alpha=0.5, # hide
            aspectratio=:equal, # hide
        ) # hide
        title!(string(beta_schedule)) # hide
        xlims!(-3, 3) # hide
        ylims!(-3, 3) # hide
        panel # hide
    end # hide
    plot(panels...; size=(1200, 800)) # hide
end # hide

gif(anim, anim.dir * ".gif", fps=20) # hide
|
||
|
|
||
|
# ## MNIST
|
||
|
|
||
|
## Features of sample 16 from the MNIST test split, rotated 90 degrees counter-clockwise # hide
## (presumably to orient the digit for `heatmap`) and rescaled to [-1, 1]. # hide
dataset = MNIST(:test)[16].features |> rotl90 |> normalize_neg_one_to_one # hide
|
||
|
|
||
|
## Draw the Gaussian noise once, outside the loop, so every frame reuses the same sample. # hide
noise = randn(Float32, size(dataset)) # hide
## Frames: 20 at t = 0, then t = 1 up to num_timesteps, then 20 at the final timestep. # hide
anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 20), dims=1) # hide
    ## One heatmap per scheduler, in the order of `beta_schedules`. # hide
    panels = map(zip(schedulers, beta_schedules)) do (scheduler, beta_schedule) # hide
        ## `forward` is only called for t >= 1; at t = 0 the image is drawn as-is. # hide
        image = t == 0 ? dataset : forward(scheduler, dataset, noise, [t]) # hide
        ## The local is deliberately not called `plot`, so it cannot shadow # hide
        ## `Plots.plot` used below. # hide
        panel = heatmap( # hide
            image, # hide
            c=:grayC, # hide
            legend=:none, # hide
            aspect_ratio=:equal, # hide
            grid=false, # hide
            axis=false # hide
        ) # hide
        title!(string(beta_schedule)) # hide
        panel # hide
    end # hide
    plot(panels...; size=(1200, 800)) # hide
end # hide

gif(anim, anim.dir * ".gif", fps=20) # hide
|