mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-09-18 18:45:28 +00:00
♻️ (examples) rework swissroll
➕ add DenoisingDiffusion.jl
This commit is contained in:
parent
9ac0b53005
commit
9d5201d068
|
@ -4,7 +4,9 @@ authors = ["Laurent Fainsin <laurent@fainsin.bzh>"]
|
|||
version = "0.1.0"
|
||||
|
||||
[deps]
|
||||
DenoisingDiffusion = "32e9e46b-ad0f-4c80-b32d-4f6f824844ef"
|
||||
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
|
||||
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
|
||||
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
|
||||
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
|
||||
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
|
||||
|
|
|
@ -1,99 +0,0 @@
|
|||
using Flux
|
||||
import Flux._show_children
|
||||
import Flux._big_show
|
||||
|
||||
|
||||
abstract type AbstractParallel end
|
||||
|
||||
_maybe_forward(layer::AbstractParallel, x::AbstractArray, ys::AbstractArray...) = layer(x, ys...)
|
||||
_maybe_forward(layer::Parallel, x::AbstractArray, ys::AbstractArray...) = layer(x, ys...)
|
||||
_maybe_forward(layer, x::AbstractArray, ys::AbstractArray...) = layer(x)
|
||||
|
||||
"""
|
||||
ConditionalChain(layers...)
|
||||
|
||||
Based off `Flux.Chain` except takes in multiple inputs.
|
||||
If a layer is of type `AbstractParallel` it uses all inputs else it uses only the first one.
|
||||
The first input can therefore be conditioned on the other inputs.
|
||||
"""
|
||||
struct ConditionalChain{T<:Union{Tuple,NamedTuple}} <: AbstractParallel
|
||||
layers::T
|
||||
end
|
||||
Flux.@functor ConditionalChain
|
||||
|
||||
ConditionalChain(xs...) = ConditionalChain(xs)
|
||||
function ConditionalChain(; kw...)
|
||||
:layers in keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
|
||||
isempty(kw) && return ConditionalChain(())
|
||||
ConditionalChain(values(kw))
|
||||
end
|
||||
|
||||
Flux.@forward ConditionalChain.layers Base.getindex, Base.length, Base.first, Base.last,
|
||||
Base.iterate, Base.lastindex, Base.keys, Base.firstindex
|
||||
|
||||
Base.getindex(c::ConditionalChain, i::AbstractArray) = ConditionalChain(c.layers[i]...)
|
||||
|
||||
function (c::ConditionalChain)(x, ys...)
|
||||
for layer in c.layers
|
||||
x = _maybe_forward(layer, x, ys...)
|
||||
end
|
||||
x
|
||||
end
|
||||
|
||||
function Base.show(io::IO, c::ConditionalChain)
|
||||
print(io, "ConditionalChain(")
|
||||
Flux._show_layers(io, c.layers)
|
||||
print(io, ")")
|
||||
end
|
||||
|
||||
function _big_show(io::IO, m::ConditionalChain{T}, indent::Integer=0, name=nothing) where {T<:NamedTuple}
|
||||
println(io, " "^indent, isnothing(name) ? "" : "$name = ", "ConditionalChain(")
|
||||
for k in Base.keys(m.layers)
|
||||
_big_show(io, m.layers[k], indent + 2, k)
|
||||
end
|
||||
if indent == 0
|
||||
print(io, ") ")
|
||||
_big_finale(io, m)
|
||||
else
|
||||
println(io, " "^indent, ")", ",")
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
ConditionalSkipConnection(layers, connection)
|
||||
|
||||
The output is equivalent to `connection(layers(x, ys...), x)`.
|
||||
Based off Flux.SkipConnection except it passes multiple arguments to layers.
|
||||
"""
|
||||
struct ConditionalSkipConnection{T,F} <: AbstractParallel
|
||||
layers::T
|
||||
connection::F
|
||||
end
|
||||
|
||||
Flux.@functor ConditionalSkipConnection
|
||||
|
||||
function (skip::ConditionalSkipConnection)(x, ys...)
|
||||
skip.connection(skip.layers(x, ys...), x)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, b::ConditionalSkipConnection)
|
||||
print(io, "ConditionalSkipConnection(", b.layers, ", ", b.connection, ")")
|
||||
end
|
||||
|
||||
### Show. Copied from Flux.jl/src/layers/show.jl
|
||||
|
||||
for T in [
|
||||
:ConditionalChain, ConditionalSkipConnection
|
||||
]
|
||||
@eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
|
||||
if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL
|
||||
Flux._big_show(io, x)
|
||||
elseif !get(io, :compact, false) # e.g. printed inside a Vector, but not a Matrix
|
||||
Flux._layer_show(io, x)
|
||||
else
|
||||
show(io, x)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
_show_children(c::ConditionalChain) = c.layers
|
|
@ -1,33 +0,0 @@
|
|||
using Flux
|
||||
|
||||
struct SinusoidalPositionEmbedding{W<:AbstractArray}
|
||||
weight::W
|
||||
end
|
||||
|
||||
Flux.@functor SinusoidalPositionEmbedding
|
||||
Flux.trainable(emb::SinusoidalPositionEmbedding) = () # mark it as an non-trainable array
|
||||
|
||||
function SinusoidalPositionEmbedding(in::Integer, out::Integer)
|
||||
W = make_positional_embedding(out, in)
|
||||
SinusoidalPositionEmbedding(W)
|
||||
end
|
||||
|
||||
function make_positional_embedding(dim_embedding::Integer, seq_length::Integer=1000; n::Integer=10000)
|
||||
embedding = Matrix{Float32}(undef, dim_embedding, seq_length)
|
||||
for pos in 1:seq_length
|
||||
for row in 0:2:(dim_embedding-1)
|
||||
denom = 1.0 / (n^(row / (dim_embedding - 2)))
|
||||
embedding[row+1, pos] = sin(pos * denom)
|
||||
embedding[row+2, pos] = cos(pos * denom)
|
||||
end
|
||||
end
|
||||
embedding
|
||||
end
|
||||
|
||||
(m::SinusoidalPositionEmbedding)(x::Integer) = m.weight[:, x]
|
||||
(m::SinusoidalPositionEmbedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
|
||||
(m::SinusoidalPositionEmbedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
|
||||
|
||||
function Base.show(io::IO, m::SinusoidalPositionEmbedding)
|
||||
print(io, "SinusoidalPositionEmbedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
|
||||
end
|
|
@ -6,16 +6,11 @@ using Flux
|
|||
using Random
|
||||
using Plots
|
||||
using ProgressMeter
|
||||
using DenoisingDiffusion
|
||||
using LaTeXStrings
|
||||
|
||||
# utils
|
||||
include("Embeddings.jl")
|
||||
include("ConditionalChain.jl")
|
||||
|
||||
function make_spiral(rng::AbstractRNG, n_samples::Integer=1000)
|
||||
t_min = 1.5π
|
||||
t_max = 4.5π
|
||||
|
||||
t = rand(rng, n_samples) * (t_max - t_min) .+ t_min
|
||||
function make_spiral(n_samples::Integer=1000, t_min::Real=1.5π, t_max::Real=4.5π)
|
||||
t = rand(n_samples) * (t_max - t_min) .+ t_min
|
||||
|
||||
x = t .* cos.(t)
|
||||
y = t .* sin.(t)
|
||||
|
@ -23,8 +18,6 @@ function make_spiral(rng::AbstractRNG, n_samples::Integer=1000)
|
|||
permutedims([x y], (2, 1))
|
||||
end
|
||||
|
||||
make_spiral(n_samples::Integer=1000) = make_spiral(Random.GLOBAL_RNG, n_samples)
|
||||
|
||||
function normalize_zero_to_one(x)
|
||||
x_min, x_max = extrema(x)
|
||||
x_norm = (x .- x_min) ./ (x_max - x_min)
|
||||
|
@ -35,9 +28,11 @@ function normalize_neg_one_to_one(x)
|
|||
2 * normalize_zero_to_one(x) .- 1
|
||||
end
|
||||
|
||||
n_samples = 1000
|
||||
data = normalize_neg_one_to_one(make_spiral(n_samples))
|
||||
scatter(data[1, :], data[2, :],
|
||||
# make a dataset of 100 spirals
|
||||
n_points = 2500
|
||||
dataset = make_spiral(n_points, 1π, 5π)
|
||||
dataset = normalize_neg_one_to_one(dataset)
|
||||
scatter(dataset[1, :], dataset[2, :],
|
||||
alpha=0.5,
|
||||
aspectratio=:equal,
|
||||
)
|
||||
|
@ -46,37 +41,56 @@ num_timesteps = 100
|
|||
scheduler = DDPM(
|
||||
Vector{Float64},
|
||||
rescale_zero_terminal_snr(
|
||||
cosine_beta_schedule(num_timesteps, 0.999f0, 0.001f0)
|
||||
cosine_beta_schedule(num_timesteps)
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
data = dataset[:, 1:100]
|
||||
noise = randn(size(data))
|
||||
|
||||
anim = @animate for i in cat(1:num_timesteps, repeat([num_timesteps], 50), dims=1)
|
||||
noisy_data = Diffusers.Schedulers.add_noise(scheduler, data, noise, [i])
|
||||
anim = @animate for t in cat(fill(0, 25), 1:num_timesteps, fill(num_timesteps, 50), dims=1)
|
||||
if t == 0
|
||||
scatter(noise[1, :], noise[2, :],
|
||||
alpha=0.1,
|
||||
alpha=0.3,
|
||||
aspectratio=:equal,
|
||||
label="noise",
|
||||
legend=:outertopright,
|
||||
)
|
||||
scatter!(noisy_data[1, :], noisy_data[2, :],
|
||||
alpha=0.5,
|
||||
aspectratio=:equal,
|
||||
label="noisy data",
|
||||
)
|
||||
scatter!(data[1, :], data[2, :],
|
||||
alpha=0.5,
|
||||
alpha=0.3,
|
||||
aspectratio=:equal,
|
||||
label="data",
|
||||
)
|
||||
i_str = lpad(i, 3, "0")
|
||||
title!("t = $(i_str)")
|
||||
scatter!(data[1, :], data[2, :],
|
||||
aspectratio=:equal,
|
||||
label="noisy data",
|
||||
)
|
||||
title!("t = " * lpad(t, 3, "0"))
|
||||
xlims!(-3, 3)
|
||||
ylims!(-3, 3)
|
||||
else
|
||||
noisy_data = Diffusers.Schedulers.add_noise(scheduler, data, noise, [t])
|
||||
scatter(noise[1, :], noise[2, :],
|
||||
alpha=0.3,
|
||||
aspectratio=:equal,
|
||||
label="noise",
|
||||
legend=:outertopright,
|
||||
)
|
||||
scatter!(data[1, :], data[2, :],
|
||||
alpha=0.3,
|
||||
aspectratio=:equal,
|
||||
label="data",
|
||||
)
|
||||
scatter!(noisy_data[1, :], noisy_data[2, :],
|
||||
aspectratio=:equal,
|
||||
label="noisy data",
|
||||
)
|
||||
title!(latexstring("t = " * lpad(t, 3, "0")))
|
||||
xlims!(-3, 3)
|
||||
ylims!(-3, 3)
|
||||
end
|
||||
end
|
||||
|
||||
gif(anim, "swissroll.gif", fps=50)
|
||||
gif(anim, anim.dir * ".gif", fps=50)
|
||||
|
||||
d_hid = 32
|
||||
model = ConditionalChain(
|
||||
|
@ -109,16 +123,16 @@ model = ConditionalChain(
|
|||
|
||||
model(data, [100])
|
||||
|
||||
num_epochs = 100
|
||||
loss = Flux.Losses.mse
|
||||
opt = Flux.setup(Adam(0.0001), model)
|
||||
dataloader = Flux.DataLoader(data |> cpu; batchsize=32, shuffle=true);
|
||||
progress = Progress(num_epochs; desc="training", showspeed=true)
|
||||
num_epochs = 100;
|
||||
loss = Flux.Losses.mse;
|
||||
opt = Flux.setup(Adam(0.0001), model);
|
||||
dataloader = Flux.DataLoader(dataset |> cpu; batchsize=32, shuffle=true);
|
||||
progress = Progress(num_epochs; desc="training", showspeed=true);
|
||||
for epoch = 1:num_epochs
|
||||
params = Flux.params(model)
|
||||
for data in dataloader
|
||||
noise = randn(size(data))
|
||||
timesteps = rand(2:num_timesteps, size(data)[2]) # TODO: fix start at timestep=2, bruh
|
||||
timesteps = rand(2:num_timesteps, size(data, ndims(data))) # TODO: fix start at timestep=2, bruh
|
||||
noisy_data = Diffusers.Schedulers.add_noise(scheduler, data, noise, timesteps)
|
||||
grads = Flux.gradient(model) do m
|
||||
model_output = m(noisy_data, timesteps)
|
||||
|
@ -130,43 +144,68 @@ for epoch = 1:num_epochs
|
|||
ProgressMeter.next!(progress)
|
||||
end
|
||||
|
||||
# sampling animation
|
||||
anim = @animate for timestep in num_timesteps:-1:2
|
||||
## sampling animation
|
||||
|
||||
sample = randn(2, 100)
|
||||
sample_old = sample
|
||||
predictions = []
|
||||
anim = for timestep in num_timesteps:-1:1
|
||||
model_output = model(data, [timestep])
|
||||
sampled_data, x0_pred = Diffusers.Schedulers.step(scheduler, data, model_output, [timestep])
|
||||
sample, x0_pred = Diffusers.Schedulers.step(scheduler, data, model_output, [timestep])
|
||||
push!(predictions, (sample, x0_pred, timestep))
|
||||
end
|
||||
|
||||
p1 = scatter(sampled_data[1, :], sampled_data[2, :],
|
||||
alpha=0.5,
|
||||
anim = @animate for i in cat(fill(0, 50), 1:num_timesteps, fill(num_timesteps, 50), dims=1)
|
||||
if i == 0
|
||||
p1 = scatter(dataset[1, :], dataset[2, :],
|
||||
alpha=0.01,
|
||||
aspectratio=:equal,
|
||||
label="sampled data",
|
||||
title=L"x_t",
|
||||
legend=false,
|
||||
)
|
||||
scatter!(data[1, :], data[2, :],
|
||||
alpha=0.5,
|
||||
aspectratio=:equal,
|
||||
label="data",
|
||||
)
|
||||
scatter!(sample_old[1, :], sample_old[2, :])
|
||||
|
||||
p2 = scatter(x0_pred[1, :], x0_pred[2, :],
|
||||
alpha=0.5,
|
||||
p2 = scatter(dataset[1, :], dataset[2, :],
|
||||
alpha=0.01,
|
||||
aspectratio=:equal,
|
||||
label="sampled data",
|
||||
title=L"x_0",
|
||||
legend=false,
|
||||
)
|
||||
scatter!(data[1, :], data[2, :],
|
||||
alpha=0.5,
|
||||
aspectratio=:equal,
|
||||
label="data",
|
||||
)
|
||||
|
||||
l = @layout [a b]
|
||||
i_str = lpad(timestep, 3, "0")
|
||||
t_str = lpad(num_timesteps, 3, "0")
|
||||
plot(p1, p2,
|
||||
layout=l,
|
||||
plot_title="t = $(i_str)",
|
||||
plot_title=latexstring("t = $(t_str)"),
|
||||
)
|
||||
xlims!(-2, 2)
|
||||
ylims!(-2, 2)
|
||||
end
|
||||
else
|
||||
sample, x_0, timestep = predictions[i]
|
||||
p1 = scatter(dataset[1, :], dataset[2, :],
|
||||
alpha=0.01,
|
||||
aspectratio=:equal,
|
||||
legend=false,
|
||||
title=L"x_t",
|
||||
)
|
||||
scatter!(sample[1, :], sample[2, :])
|
||||
|
||||
gif(anim, "sampling.gif", fps=30)
|
||||
p2 = scatter(dataset[1, :], dataset[2, :],
|
||||
alpha=0.01,
|
||||
aspectratio=:equal,
|
||||
legend=false,
|
||||
title=L"x_0",
|
||||
)
|
||||
scatter!(x_0[1, :], x_0[2, :])
|
||||
|
||||
l = @layout [a b]
|
||||
t_str = lpad(timestep - 1, 3, "0")
|
||||
plot(p1, p2,
|
||||
layout=l,
|
||||
plot_title=latexstring("t = $(t_str)"),
|
||||
)
|
||||
xlims!(-2, 2)
|
||||
ylims!(-2, 2)
|
||||
end
|
||||
end
|
||||
gif(anim, anim.dir * ".gif", fps=50)
|
||||
|
|
|
@ -87,12 +87,17 @@ function add_noise(
|
|||
ϵ::AbstractArray,
|
||||
t::AbstractArray,
|
||||
)
|
||||
⎷α̅ₜ = scheduler.⎷α̅[t]
|
||||
⎷β̅ₜ = scheduler.⎷β̅[t]
|
||||
# retreive scheduler variables at timesteps t
|
||||
reshape_size = tuple(
|
||||
fill(1, ndims(x₀) - 1)...,
|
||||
size(t, 1)
|
||||
)
|
||||
⎷α̅ₜ = reshape(scheduler.⎷α̅[t], reshape_size)
|
||||
⎷β̅ₜ = reshape(scheduler.⎷β̅[t], reshape_size)
|
||||
|
||||
# noisify clean data
|
||||
# arxiv:2006.11239 Eq. 4
|
||||
xₜ = ⎷α̅ₜ' .* x₀ + ⎷β̅ₜ' .* ϵ
|
||||
xₜ = ⎷α̅ₜ .* x₀ + ⎷β̅ₜ .* ϵ
|
||||
|
||||
return xₜ
|
||||
end
|
||||
|
@ -117,31 +122,35 @@ function step(
|
|||
t::AbstractArray,
|
||||
)
|
||||
# retreive scheduler variables at timesteps t
|
||||
βₜ = scheduler.β[t]
|
||||
β̅ₜ = scheduler.β̅[t]
|
||||
β̅ₜ₋₁ = scheduler.β̅₋₁[t]
|
||||
⎷αₜ = scheduler.⎷α[t]
|
||||
⎷α̅ₜ = scheduler.⎷α̅[t]
|
||||
⎷α̅ₜ₋₁ = scheduler.⎷α̅₋₁[t]
|
||||
⎷β̅ₜ = scheduler.⎷β̅[t]
|
||||
reshape_size = tuple(
|
||||
fill(1, ndims(xₜ) - 1)...,
|
||||
size(t, 1)
|
||||
)
|
||||
βₜ = reshape(scheduler.β[t], reshape_size)
|
||||
β̅ₜ = reshape(scheduler.β̅[t], reshape_size)
|
||||
β̅ₜ₋₁ = reshape(scheduler.β̅₋₁[t], reshape_size)
|
||||
⎷αₜ = reshape(scheduler.⎷α[t], reshape_size)
|
||||
⎷α̅ₜ = reshape(scheduler.⎷α̅[t], reshape_size)
|
||||
⎷α̅ₜ₋₁ = reshape(scheduler.⎷α̅₋₁[t], reshape_size)
|
||||
⎷β̅ₜ = reshape(scheduler.⎷β̅[t], reshape_size)
|
||||
|
||||
# compute predicted previous sample x̂₀
|
||||
# arxiv:2006.11239 Eq. 15
|
||||
# arxiv:2208.11970 Eq. 115
|
||||
x̂₀ = (xₜ - ⎷β̅ₜ' .* ϵᵧ) ./ ⎷α̅ₜ'
|
||||
x̂₀ = (xₜ - ⎷β̅ₜ .* ϵᵧ) ./ ⎷α̅ₜ
|
||||
|
||||
# compute predicted previous sample μ̃ₜ
|
||||
# arxiv:2006.11239 Eq. 7
|
||||
# arxiv:2208.11970 Eq. 84
|
||||
λ₀ = ⎷α̅ₜ₋₁ .* βₜ ./ β̅ₜ
|
||||
λₜ = ⎷αₜ .* β̅ₜ₋₁ ./ β̅ₜ # TODO: this could be stored in the scheduler
|
||||
μ̃ₜ = λ₀' .* x̂₀ + λₜ' .* xₜ
|
||||
μ̃ₜ = λ₀ .* x̂₀ + λₜ .* xₜ
|
||||
|
||||
# sample predicted previous sample xₜ₋₁
|
||||
# arxiv:2006.11239 Eq. 6
|
||||
# arxiv:2208.11970 Eq. 70
|
||||
σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ # TODO: this could be stored in the scheduler
|
||||
xₜ₋₁ = μ̃ₜ + σₜ' .* randn(size(ϵᵧ))
|
||||
xₜ₋₁ = μ̃ₜ + σₜ .* randn(size(ϵᵧ))
|
||||
|
||||
return xₜ₋₁, x̂₀
|
||||
end
|
||||
|
|
Loading…
Reference in a new issue