♻️ (examples) rework swissroll

 add DenoisingDiffusion.jl
This commit is contained in:
Laureηt 2023-07-30 17:55:22 +02:00
parent 9ac0b53005
commit 9d5201d068
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
5 changed files with 146 additions and 228 deletions

View file

@ -4,7 +4,9 @@ authors = ["Laurent Fainsin <laurent@fainsin.bzh>"]
version = "0.1.0"
[deps]
DenoisingDiffusion = "32e9e46b-ad0f-4c80-b32d-4f6f824844ef"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"

View file

@ -1,99 +0,0 @@
using Flux
import Flux._show_children
import Flux._big_show
abstract type AbstractParallel end
_maybe_forward(layer::AbstractParallel, x::AbstractArray, ys::AbstractArray...) = layer(x, ys...)
_maybe_forward(layer::Parallel, x::AbstractArray, ys::AbstractArray...) = layer(x, ys...)
_maybe_forward(layer, x::AbstractArray, ys::AbstractArray...) = layer(x)
"""
ConditionalChain(layers...)
Based off `Flux.Chain` except takes in multiple inputs.
If a layer is of type `AbstractParallel` it uses all inputs else it uses only the first one.
The first input can therefore be conditioned on the other inputs.
"""
struct ConditionalChain{T<:Union{Tuple,NamedTuple}} <: AbstractParallel
layers::T
end
Flux.@functor ConditionalChain
ConditionalChain(xs...) = ConditionalChain(xs)
function ConditionalChain(; kw...)
:layers in keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
isempty(kw) && return ConditionalChain(())
ConditionalChain(values(kw))
end
Flux.@forward ConditionalChain.layers Base.getindex, Base.length, Base.first, Base.last,
Base.iterate, Base.lastindex, Base.keys, Base.firstindex
Base.getindex(c::ConditionalChain, i::AbstractArray) = ConditionalChain(c.layers[i]...)
function (c::ConditionalChain)(x, ys...)
for layer in c.layers
x = _maybe_forward(layer, x, ys...)
end
x
end
function Base.show(io::IO, c::ConditionalChain)
print(io, "ConditionalChain(")
Flux._show_layers(io, c.layers)
print(io, ")")
end
function _big_show(io::IO, m::ConditionalChain{T}, indent::Integer=0, name=nothing) where {T<:NamedTuple}
println(io, " "^indent, isnothing(name) ? "" : "$name = ", "ConditionalChain(")
for k in Base.keys(m.layers)
_big_show(io, m.layers[k], indent + 2, k)
end
if indent == 0
print(io, ") ")
_big_finale(io, m)
else
println(io, " "^indent, ")", ",")
end
end
"""
ConditionalSkipConnection(layers, connection)
The output is equivalent to `connection(layers(x, ys...), x)`.
Based off Flux.SkipConnection except it passes multiple arguments to layers.
"""
struct ConditionalSkipConnection{T,F} <: AbstractParallel
layers::T
connection::F
end
Flux.@functor ConditionalSkipConnection
function (skip::ConditionalSkipConnection)(x, ys...)
skip.connection(skip.layers(x, ys...), x)
end
function Base.show(io::IO, b::ConditionalSkipConnection)
print(io, "ConditionalSkipConnection(", b.layers, ", ", b.connection, ")")
end
### Show. Copied from Flux.jl/src/layers/show.jl
for T in [
:ConditionalChain, ConditionalSkipConnection
]
@eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL
Flux._big_show(io, x)
elseif !get(io, :compact, false) # e.g. printed inside a Vector, but not a Matrix
Flux._layer_show(io, x)
else
show(io, x)
end
end
end
_show_children(c::ConditionalChain) = c.layers

View file

@ -1,33 +0,0 @@
using Flux
struct SinusoidalPositionEmbedding{W<:AbstractArray}
weight::W
end
Flux.@functor SinusoidalPositionEmbedding
Flux.trainable(emb::SinusoidalPositionEmbedding) = () # mark it as an non-trainable array
function SinusoidalPositionEmbedding(in::Integer, out::Integer)
W = make_positional_embedding(out, in)
SinusoidalPositionEmbedding(W)
end
function make_positional_embedding(dim_embedding::Integer, seq_length::Integer=1000; n::Integer=10000)
embedding = Matrix{Float32}(undef, dim_embedding, seq_length)
for pos in 1:seq_length
for row in 0:2:(dim_embedding-1)
denom = 1.0 / (n^(row / (dim_embedding - 2)))
embedding[row+1, pos] = sin(pos * denom)
embedding[row+2, pos] = cos(pos * denom)
end
end
embedding
end
(m::SinusoidalPositionEmbedding)(x::Integer) = m.weight[:, x]
(m::SinusoidalPositionEmbedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
(m::SinusoidalPositionEmbedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
function Base.show(io::IO, m::SinusoidalPositionEmbedding)
print(io, "SinusoidalPositionEmbedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
end

View file

@ -6,16 +6,11 @@ using Flux
using Random
using Plots
using ProgressMeter
using DenoisingDiffusion
using LaTeXStrings
# utils
include("Embeddings.jl")
include("ConditionalChain.jl")
function make_spiral(rng::AbstractRNG, n_samples::Integer=1000)
t_min = 1.5π
t_max = 4.5π
t = rand(rng, n_samples) * (t_max - t_min) .+ t_min
function make_spiral(n_samples::Integer=1000, t_min::Real=1.5π, t_max::Real=4.5π)
t = rand(n_samples) * (t_max - t_min) .+ t_min
x = t .* cos.(t)
y = t .* sin.(t)
@ -23,8 +18,6 @@ function make_spiral(rng::AbstractRNG, n_samples::Integer=1000)
permutedims([x y], (2, 1))
end
make_spiral(n_samples::Integer=1000) = make_spiral(Random.GLOBAL_RNG, n_samples)
function normalize_zero_to_one(x)
x_min, x_max = extrema(x)
x_norm = (x .- x_min) ./ (x_max - x_min)
@ -35,9 +28,11 @@ function normalize_neg_one_to_one(x)
2 * normalize_zero_to_one(x) .- 1
end
n_samples = 1000
data = normalize_neg_one_to_one(make_spiral(n_samples))
scatter(data[1, :], data[2, :],
# make a dataset of 100 spirals
n_points = 2500
dataset = make_spiral(n_points, 1π, 5π)
dataset = normalize_neg_one_to_one(dataset)
scatter(dataset[1, :], dataset[2, :],
alpha=0.5,
aspectratio=:equal,
)
@ -46,37 +41,56 @@ num_timesteps = 100
scheduler = DDPM(
Vector{Float64},
rescale_zero_terminal_snr(
cosine_beta_schedule(num_timesteps, 0.999f0, 0.001f0)
cosine_beta_schedule(num_timesteps)
)
)
);
data = dataset[:, 1:100]
noise = randn(size(data))
anim = @animate for i in cat(1:num_timesteps, repeat([num_timesteps], 50), dims=1)
noisy_data = Diffusers.Schedulers.add_noise(scheduler, data, noise, [i])
scatter(noise[1, :], noise[2, :],
alpha=0.1,
aspectratio=:equal,
label="noise",
legend=:outertopright,
)
scatter!(noisy_data[1, :], noisy_data[2, :],
alpha=0.5,
aspectratio=:equal,
label="noisy data",
)
scatter!(data[1, :], data[2, :],
alpha=0.5,
aspectratio=:equal,
label="data",
)
i_str = lpad(i, 3, "0")
title!("t = $(i_str)")
xlims!(-3, 3)
ylims!(-3, 3)
anim = @animate for t in cat(fill(0, 25), 1:num_timesteps, fill(num_timesteps, 50), dims=1)
if t == 0
scatter(noise[1, :], noise[2, :],
alpha=0.3,
aspectratio=:equal,
label="noise",
legend=:outertopright,
)
scatter!(data[1, :], data[2, :],
alpha=0.3,
aspectratio=:equal,
label="data",
)
scatter!(data[1, :], data[2, :],
aspectratio=:equal,
label="noisy data",
)
title!("t = " * lpad(t, 3, "0"))
xlims!(-3, 3)
ylims!(-3, 3)
else
noisy_data = Diffusers.Schedulers.add_noise(scheduler, data, noise, [t])
scatter(noise[1, :], noise[2, :],
alpha=0.3,
aspectratio=:equal,
label="noise",
legend=:outertopright,
)
scatter!(data[1, :], data[2, :],
alpha=0.3,
aspectratio=:equal,
label="data",
)
scatter!(noisy_data[1, :], noisy_data[2, :],
aspectratio=:equal,
label="noisy data",
)
title!(latexstring("t = " * lpad(t, 3, "0")))
xlims!(-3, 3)
ylims!(-3, 3)
end
end
gif(anim, "swissroll.gif", fps=50)
gif(anim, anim.dir * ".gif", fps=50)
d_hid = 32
model = ConditionalChain(
@ -109,16 +123,16 @@ model = ConditionalChain(
model(data, [100])
num_epochs = 100
loss = Flux.Losses.mse
opt = Flux.setup(Adam(0.0001), model)
dataloader = Flux.DataLoader(data |> cpu; batchsize=32, shuffle=true);
progress = Progress(num_epochs; desc="training", showspeed=true)
num_epochs = 100;
loss = Flux.Losses.mse;
opt = Flux.setup(Adam(0.0001), model);
dataloader = Flux.DataLoader(dataset |> cpu; batchsize=32, shuffle=true);
progress = Progress(num_epochs; desc="training", showspeed=true);
for epoch = 1:num_epochs
params = Flux.params(model)
for data in dataloader
noise = randn(size(data))
timesteps = rand(2:num_timesteps, size(data)[2]) # TODO: fix start at timestep=2, bruh
timesteps = rand(2:num_timesteps, size(data, ndims(data))) # TODO: fix start at timestep=2, bruh
noisy_data = Diffusers.Schedulers.add_noise(scheduler, data, noise, timesteps)
grads = Flux.gradient(model) do m
model_output = m(noisy_data, timesteps)
@ -130,43 +144,68 @@ for epoch = 1:num_epochs
ProgressMeter.next!(progress)
end
# sampling animation
anim = @animate for timestep in num_timesteps:-1:2
## sampling animation
sample = randn(2, 100)
sample_old = sample
predictions = []
anim = for timestep in num_timesteps:-1:1
model_output = model(data, [timestep])
sampled_data, x0_pred = Diffusers.Schedulers.step(scheduler, data, model_output, [timestep])
p1 = scatter(sampled_data[1, :], sampled_data[2, :],
alpha=0.5,
aspectratio=:equal,
label="sampled data",
legend=false,
)
scatter!(data[1, :], data[2, :],
alpha=0.5,
aspectratio=:equal,
label="data",
)
p2 = scatter(x0_pred[1, :], x0_pred[2, :],
alpha=0.5,
aspectratio=:equal,
label="sampled data",
legend=false,
)
scatter!(data[1, :], data[2, :],
alpha=0.5,
aspectratio=:equal,
label="data",
)
l = @layout [a b]
i_str = lpad(timestep, 3, "0")
plot(p1, p2,
layout=l,
plot_title="t = $(i_str)",
)
xlims!(-2, 2)
ylims!(-2, 2)
sample, x0_pred = Diffusers.Schedulers.step(scheduler, data, model_output, [timestep])
push!(predictions, (sample, x0_pred, timestep))
end
gif(anim, "sampling.gif", fps=30)
anim = @animate for i in cat(fill(0, 50), 1:num_timesteps, fill(num_timesteps, 50), dims=1)
if i == 0
p1 = scatter(dataset[1, :], dataset[2, :],
alpha=0.01,
aspectratio=:equal,
title=L"x_t",
legend=false,
)
scatter!(sample_old[1, :], sample_old[2, :])
p2 = scatter(dataset[1, :], dataset[2, :],
alpha=0.01,
aspectratio=:equal,
title=L"x_0",
legend=false,
)
l = @layout [a b]
t_str = lpad(num_timesteps, 3, "0")
plot(p1, p2,
layout=l,
plot_title=latexstring("t = $(t_str)"),
)
xlims!(-2, 2)
ylims!(-2, 2)
else
sample, x_0, timestep = predictions[i]
p1 = scatter(dataset[1, :], dataset[2, :],
alpha=0.01,
aspectratio=:equal,
legend=false,
title=L"x_t",
)
scatter!(sample[1, :], sample[2, :])
p2 = scatter(dataset[1, :], dataset[2, :],
alpha=0.01,
aspectratio=:equal,
legend=false,
title=L"x_0",
)
scatter!(x_0[1, :], x_0[2, :])
l = @layout [a b]
t_str = lpad(timestep - 1, 3, "0")
plot(p1, p2,
layout=l,
plot_title=latexstring("t = $(t_str)"),
)
xlims!(-2, 2)
ylims!(-2, 2)
end
end
gif(anim, anim.dir * ".gif", fps=50)

View file

@ -87,12 +87,17 @@ function add_noise(
ϵ::AbstractArray,
t::AbstractArray,
)
⎷α̅ₜ = scheduler.⎷α̅[t]
⎷β̅ₜ = scheduler.⎷β̅[t]
# retreive scheduler variables at timesteps t
reshape_size = tuple(
fill(1, ndims(x₀) - 1)...,
size(t, 1)
)
⎷α̅ₜ = reshape(scheduler.⎷α̅[t], reshape_size)
⎷β̅ₜ = reshape(scheduler.⎷β̅[t], reshape_size)
# noisify clean data
# arxiv:2006.11239 Eq. 4
xₜ = ⎷α̅ₜ' .* x₀ + ⎷β̅ₜ' .* ϵ
xₜ = ⎷α̅ₜ .* x₀ + ⎷β̅ₜ .* ϵ
return xₜ
end
@ -117,31 +122,35 @@ function step(
t::AbstractArray,
)
# retreive scheduler variables at timesteps t
βₜ = scheduler.β[t]
β̅ₜ = scheduler.β̅[t]
β̅ₜ₋₁ = scheduler.β̅₋₁[t]
⎷αₜ = scheduler.⎷α[t]
⎷α̅ₜ = scheduler.⎷α̅[t]
⎷α̅ₜ₋₁ = scheduler.⎷α̅₋₁[t]
⎷β̅ₜ = scheduler.⎷β̅[t]
reshape_size = tuple(
fill(1, ndims(xₜ) - 1)...,
size(t, 1)
)
βₜ = reshape(scheduler.β[t], reshape_size)
β̅ₜ = reshape(scheduler.β̅[t], reshape_size)
β̅ₜ₋₁ = reshape(scheduler.β̅₋₁[t], reshape_size)
⎷αₜ = reshape(scheduler.⎷α[t], reshape_size)
⎷α̅ₜ = reshape(scheduler.⎷α̅[t], reshape_size)
⎷α̅ₜ₋₁ = reshape(scheduler.⎷α̅₋₁[t], reshape_size)
⎷β̅ₜ = reshape(scheduler.⎷β̅[t], reshape_size)
# compute predicted previous sample x̂₀
# arxiv:2006.11239 Eq. 15
# arxiv:2208.11970 Eq. 115
x̂₀ = (xₜ - ⎷β̅ₜ' .* ϵᵧ) ./ ⎷α̅ₜ'
x̂₀ = (xₜ - ⎷β̅ₜ .* ϵᵧ) ./ ⎷α̅ₜ
# compute predicted previous sample μ̃ₜ
# arxiv:2006.11239 Eq. 7
# arxiv:2208.11970 Eq. 84
λ₀ = ⎷α̅ₜ₋₁ .* βₜ ./ β̅ₜ
λₜ = ⎷αₜ .* β̅ₜ₋₁ ./ β̅ₜ # TODO: this could be stored in the scheduler
μ̃ₜ = λ₀' .* x̂₀ + λₜ' .* xₜ
μ̃ₜ = λ₀ .* x̂₀ + λₜ .* xₜ
# sample predicted previous sample xₜ₋₁
# arxiv:2006.11239 Eq. 6
# arxiv:2208.11970 Eq. 70
σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ # TODO: this could be stored in the scheduler
xₜ₋₁ = μ̃ₜ + σₜ' .* randn(size(ϵᵧ))
xₜ₋₁ = μ̃ₜ + σₜ .* randn(size(ϵᵧ))
return xₜ₋₁, x̂₀
end