diff --git a/Project.toml b/Project.toml
index bba1cd2..896edbc 100644
--- a/Project.toml
+++ b/Project.toml
@@ -4,7 +4,9 @@ authors = ["Laurent Fainsin "]
 version = "0.1.0"
 
 [deps]
+DenoisingDiffusion = "32e9e46b-ad0f-4c80-b32d-4f6f824844ef"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
diff --git a/examples/ConditionalChain.jl b/examples/ConditionalChain.jl
deleted file mode 100644
index 0c279ee..0000000
--- a/examples/ConditionalChain.jl
+++ /dev/null
@@ -1,99 +0,0 @@
-using Flux
-import Flux._show_children
-import Flux._big_show
-
-
-abstract type AbstractParallel end
-
-_maybe_forward(layer::AbstractParallel, x::AbstractArray, ys::AbstractArray...) = layer(x, ys...)
-_maybe_forward(layer::Parallel, x::AbstractArray, ys::AbstractArray...) = layer(x, ys...)
-_maybe_forward(layer, x::AbstractArray, ys::AbstractArray...) = layer(x)
-
-"""
-    ConditionalChain(layers...)
-
-Based off `Flux.Chain` except takes in multiple inputs.
-If a layer is of type `AbstractParallel` it uses all inputs else it uses only the first one.
-The first input can therefore be conditioned on the other inputs.
-"""
-struct ConditionalChain{T<:Union{Tuple,NamedTuple}} <: AbstractParallel
-    layers::T
-end
-Flux.@functor ConditionalChain
-
-ConditionalChain(xs...) = ConditionalChain(xs)
-function ConditionalChain(; kw...)
-    :layers in keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
-    isempty(kw) && return ConditionalChain(())
-    ConditionalChain(values(kw))
-end
-
-Flux.@forward ConditionalChain.layers Base.getindex, Base.length, Base.first, Base.last,
-Base.iterate, Base.lastindex, Base.keys, Base.firstindex
-
-Base.getindex(c::ConditionalChain, i::AbstractArray) = ConditionalChain(c.layers[i]...)
-
-function (c::ConditionalChain)(x, ys...)
-    for layer in c.layers
-        x = _maybe_forward(layer, x, ys...)
-    end
-    x
-end
-
-function Base.show(io::IO, c::ConditionalChain)
-    print(io, "ConditionalChain(")
-    Flux._show_layers(io, c.layers)
-    print(io, ")")
-end
-
-function _big_show(io::IO, m::ConditionalChain{T}, indent::Integer=0, name=nothing) where {T<:NamedTuple}
-    println(io, " "^indent, isnothing(name) ? "" : "$name = ", "ConditionalChain(")
-    for k in Base.keys(m.layers)
-        _big_show(io, m.layers[k], indent + 2, k)
-    end
-    if indent == 0
-        print(io, ") ")
-        _big_finale(io, m)
-    else
-        println(io, " "^indent, ")", ",")
-    end
-end
-
-"""
-    ConditionalSkipConnection(layers, connection)
-
-The output is equivalent to `connection(layers(x, ys...), x)`.
-Based off Flux.SkipConnection except it passes multiple arguments to layers.
-"""
-struct ConditionalSkipConnection{T,F} <: AbstractParallel
-    layers::T
-    connection::F
-end
-
-Flux.@functor ConditionalSkipConnection
-
-function (skip::ConditionalSkipConnection)(x, ys...)
-    skip.connection(skip.layers(x, ys...), x)
-end
-
-function Base.show(io::IO, b::ConditionalSkipConnection)
-    print(io, "ConditionalSkipConnection(", b.layers, ", ", b.connection, ")")
-end
-
-### Show. Copied from Flux.jl/src/layers/show.jl
-
-for T in [
-    :ConditionalChain, ConditionalSkipConnection
-]
-    @eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
-        if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL
-            Flux._big_show(io, x)
-        elseif !get(io, :compact, false) # e.g. printed inside a Vector, but not a Matrix
-            Flux._layer_show(io, x)
-        else
-            show(io, x)
-        end
-    end
-end
-
-_show_children(c::ConditionalChain) = c.layers
diff --git a/examples/Embeddings.jl b/examples/Embeddings.jl
deleted file mode 100644
index 746b8e6..0000000
--- a/examples/Embeddings.jl
+++ /dev/null
@@ -1,33 +0,0 @@
-using Flux
-
-struct SinusoidalPositionEmbedding{W<:AbstractArray}
-    weight::W
-end
-
-Flux.@functor SinusoidalPositionEmbedding
-Flux.trainable(emb::SinusoidalPositionEmbedding) = () # mark it as an non-trainable array
-
-function SinusoidalPositionEmbedding(in::Integer, out::Integer)
-    W = make_positional_embedding(out, in)
-    SinusoidalPositionEmbedding(W)
-end
-
-function make_positional_embedding(dim_embedding::Integer, seq_length::Integer=1000; n::Integer=10000)
-    embedding = Matrix{Float32}(undef, dim_embedding, seq_length)
-    for pos in 1:seq_length
-        for row in 0:2:(dim_embedding-1)
-            denom = 1.0 / (n^(row / (dim_embedding - 2)))
-            embedding[row+1, pos] = sin(pos * denom)
-            embedding[row+2, pos] = cos(pos * denom)
-        end
-    end
-    embedding
-end
-
-(m::SinusoidalPositionEmbedding)(x::Integer) = m.weight[:, x]
-(m::SinusoidalPositionEmbedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
-(m::SinusoidalPositionEmbedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
-
-function Base.show(io::IO, m::SinusoidalPositionEmbedding)
-    print(io, "SinusoidalPositionEmbedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
-end
diff --git a/examples/swissroll.jl b/examples/swissroll.jl
index 6de614f..1122949 100644
--- a/examples/swissroll.jl
+++ b/examples/swissroll.jl
@@ -6,16 +6,11 @@ using Flux
 using Random
 using Plots
 using ProgressMeter
+using DenoisingDiffusion
+using LaTeXStrings
 
-# utils
-include("Embeddings.jl")
-include("ConditionalChain.jl")
-
-function make_spiral(rng::AbstractRNG, n_samples::Integer=1000)
-    t_min = 1.5π
-    t_max = 4.5π
-
-    t = rand(rng, n_samples) * (t_max - t_min) .+ t_min
+function make_spiral(n_samples::Integer=1000, t_min::Real=1.5π, t_max::Real=4.5π)
+    t = rand(n_samples) * (t_max - t_min) .+ t_min
 
     x = t .* cos.(t)
     y = t .* sin.(t)
@@ -23,8 +18,6 @@ function make_spiral(rng::AbstractRNG, n_samples::Integer=1000)
     permutedims([x y], (2, 1))
 end
 
-make_spiral(n_samples::Integer=1000) = make_spiral(Random.GLOBAL_RNG, n_samples)
-
 function normalize_zero_to_one(x)
     x_min, x_max = extrema(x)
     x_norm = (x .- x_min) ./ (x_max - x_min)
@@ -35,9 +28,11 @@ function normalize_neg_one_to_one(x)
     2 * normalize_zero_to_one(x) .- 1
 end
 
-n_samples = 1000
-data = normalize_neg_one_to_one(make_spiral(n_samples))
-scatter(data[1, :], data[2, :],
+# make a dataset of 2500 spiral points
+n_points = 2500
+dataset = make_spiral(n_points, 1π, 5π)
+dataset = normalize_neg_one_to_one(dataset)
+scatter(dataset[1, :], dataset[2, :],
     alpha=0.5,
     aspectratio=:equal,
 )
@@ -46,37 +41,56 @@ num_timesteps = 100
 scheduler = DDPM(
     Vector{Float64},
     rescale_zero_terminal_snr(
-        cosine_beta_schedule(num_timesteps, 0.999f0, 0.001f0)
+        cosine_beta_schedule(num_timesteps)
     )
-)
+);
 
+data = dataset[:, 1:100]
 noise = randn(size(data))
 
-anim = @animate for i in cat(1:num_timesteps, repeat([num_timesteps], 50), dims=1)
-    noisy_data = Diffusers.Schedulers.add_noise(scheduler, data, noise, [i])
-    scatter(noise[1, :], noise[2, :],
-        alpha=0.1,
-        aspectratio=:equal,
-        label="noise",
-        legend=:outertopright,
-    )
-    scatter!(noisy_data[1, :], noisy_data[2, :],
-        alpha=0.5,
-        aspectratio=:equal,
-        label="noisy data",
-    )
-    scatter!(data[1, :], data[2, :],
-        alpha=0.5,
-        aspectratio=:equal,
label="data", - ) - i_str = lpad(i, 3, "0") - title!("t = $(i_str)") - xlims!(-3, 3) - ylims!(-3, 3) +anim = @animate for t in cat(fill(0, 25), 1:num_timesteps, fill(num_timesteps, 50), dims=1) + if t == 0 + scatter(noise[1, :], noise[2, :], + alpha=0.3, + aspectratio=:equal, + label="noise", + legend=:outertopright, + ) + scatter!(data[1, :], data[2, :], + alpha=0.3, + aspectratio=:equal, + label="data", + ) + scatter!(data[1, :], data[2, :], + aspectratio=:equal, + label="noisy data", + ) + title!("t = " * lpad(t, 3, "0")) + xlims!(-3, 3) + ylims!(-3, 3) + else + noisy_data = Diffusers.Schedulers.add_noise(scheduler, data, noise, [t]) + scatter(noise[1, :], noise[2, :], + alpha=0.3, + aspectratio=:equal, + label="noise", + legend=:outertopright, + ) + scatter!(data[1, :], data[2, :], + alpha=0.3, + aspectratio=:equal, + label="data", + ) + scatter!(noisy_data[1, :], noisy_data[2, :], + aspectratio=:equal, + label="noisy data", + ) + title!(latexstring("t = " * lpad(t, 3, "0"))) + xlims!(-3, 3) + ylims!(-3, 3) + end end - -gif(anim, "swissroll.gif", fps=50) +gif(anim, anim.dir * ".gif", fps=50) d_hid = 32 model = ConditionalChain( @@ -109,16 +123,16 @@ model = ConditionalChain( model(data, [100]) -num_epochs = 100 -loss = Flux.Losses.mse -opt = Flux.setup(Adam(0.0001), model) -dataloader = Flux.DataLoader(data |> cpu; batchsize=32, shuffle=true); -progress = Progress(num_epochs; desc="training", showspeed=true) +num_epochs = 100; +loss = Flux.Losses.mse; +opt = Flux.setup(Adam(0.0001), model); +dataloader = Flux.DataLoader(dataset |> cpu; batchsize=32, shuffle=true); +progress = Progress(num_epochs; desc="training", showspeed=true); for epoch = 1:num_epochs params = Flux.params(model) for data in dataloader noise = randn(size(data)) - timesteps = rand(2:num_timesteps, size(data)[2]) # TODO: fix start at timestep=2, bruh + timesteps = rand(2:num_timesteps, size(data, ndims(data))) # TODO: fix start at timestep=2, bruh noisy_data = Diffusers.Schedulers.add_noise(scheduler, data, noise, timesteps) grads = Flux.gradient(model) do m model_output = m(noisy_data, timesteps) @@ -130,43 +144,68 @@ for epoch = 1:num_epochs ProgressMeter.next!(progress) end -# sampling animation -anim = @animate for timestep in num_timesteps:-1:2 +## sampling animation + +sample = randn(2, 100) +sample_old = sample +predictions = [] +anim = for timestep in num_timesteps:-1:1 model_output = model(data, [timestep]) - sampled_data, x0_pred = Diffusers.Schedulers.step(scheduler, data, model_output, [timestep]) - - p1 = scatter(sampled_data[1, :], sampled_data[2, :], - alpha=0.5, - aspectratio=:equal, - label="sampled data", - legend=false, - ) - scatter!(data[1, :], data[2, :], - alpha=0.5, - aspectratio=:equal, - label="data", - ) - - p2 = scatter(x0_pred[1, :], x0_pred[2, :], - alpha=0.5, - aspectratio=:equal, - label="sampled data", - legend=false, - ) - scatter!(data[1, :], data[2, :], - alpha=0.5, - aspectratio=:equal, - label="data", - ) - - l = @layout [a b] - i_str = lpad(timestep, 3, "0") - plot(p1, p2, - layout=l, - plot_title="t = $(i_str)", - ) - xlims!(-2, 2) - ylims!(-2, 2) + sample, x0_pred = Diffusers.Schedulers.step(scheduler, data, model_output, [timestep]) + push!(predictions, (sample, x0_pred, timestep)) end -gif(anim, "sampling.gif", fps=30) +anim = @animate for i in cat(fill(0, 50), 1:num_timesteps, fill(num_timesteps, 50), dims=1) + if i == 0 + p1 = scatter(dataset[1, :], dataset[2, :], + alpha=0.01, + aspectratio=:equal, + title=L"x_t", + legend=false, + ) + scatter!(sample_old[1, :], 
+
+        p2 = scatter(dataset[1, :], dataset[2, :],
+            alpha=0.01,
+            aspectratio=:equal,
+            title=L"x_0",
+            legend=false,
+        )
+
+        l = @layout [a b]
+        t_str = lpad(num_timesteps, 3, "0")
+        plot(p1, p2,
+            layout=l,
+            plot_title=latexstring("t = $(t_str)"),
+        )
+        xlims!(-2, 2)
+        ylims!(-2, 2)
+    else
+        sample, x_0, timestep = predictions[i]
+        p1 = scatter(dataset[1, :], dataset[2, :],
+            alpha=0.01,
+            aspectratio=:equal,
+            legend=false,
+            title=L"x_t",
+        )
+        scatter!(sample[1, :], sample[2, :])
+
+        p2 = scatter(dataset[1, :], dataset[2, :],
+            alpha=0.01,
+            aspectratio=:equal,
+            legend=false,
+            title=L"x_0",
+        )
+        scatter!(x_0[1, :], x_0[2, :])
+
+        l = @layout [a b]
+        t_str = lpad(timestep - 1, 3, "0")
+        plot(p1, p2,
+            layout=l,
+            plot_title=latexstring("t = $(t_str)"),
+        )
+        xlims!(-2, 2)
+        ylims!(-2, 2)
+    end
+end
+gif(anim, anim.dir * ".gif", fps=50)
diff --git a/src/Schedulers/DDPM.jl b/src/Schedulers/DDPM.jl
index bbc4ac6..9e14dad 100644
--- a/src/Schedulers/DDPM.jl
+++ b/src/Schedulers/DDPM.jl
@@ -87,12 +87,17 @@ function add_noise(
     ϵ::AbstractArray,
     t::AbstractArray,
 )
-    ⎷α̅ₜ = scheduler.⎷α̅[t]
-    ⎷β̅ₜ = scheduler.⎷β̅[t]
+    # retrieve scheduler variables at timesteps t
+    reshape_size = tuple(
+        fill(1, ndims(x₀) - 1)...,
+        size(t, 1)
+    )
+    ⎷α̅ₜ = reshape(scheduler.⎷α̅[t], reshape_size)
+    ⎷β̅ₜ = reshape(scheduler.⎷β̅[t], reshape_size)
 
     # noisify clean data
     # arxiv:2006.11239 Eq. 4
-    xₜ = ⎷α̅ₜ' .* x₀ + ⎷β̅ₜ' .* ϵ
+    xₜ = ⎷α̅ₜ .* x₀ + ⎷β̅ₜ .* ϵ
 
     return xₜ
 end
@@ -117,31 +122,35 @@ function step(
     t::AbstractArray,
 )
     # retrieve scheduler variables at timesteps t
-    βₜ = scheduler.β[t]
-    β̅ₜ = scheduler.β̅[t]
-    β̅ₜ₋₁ = scheduler.β̅₋₁[t]
-    ⎷αₜ = scheduler.⎷α[t]
-    ⎷α̅ₜ = scheduler.⎷α̅[t]
-    ⎷α̅ₜ₋₁ = scheduler.⎷α̅₋₁[t]
-    ⎷β̅ₜ = scheduler.⎷β̅[t]
+    reshape_size = tuple(
+        fill(1, ndims(xₜ) - 1)...,
+        size(t, 1)
+    )
+    βₜ = reshape(scheduler.β[t], reshape_size)
+    β̅ₜ = reshape(scheduler.β̅[t], reshape_size)
+    β̅ₜ₋₁ = reshape(scheduler.β̅₋₁[t], reshape_size)
+    ⎷αₜ = reshape(scheduler.⎷α[t], reshape_size)
+    ⎷α̅ₜ = reshape(scheduler.⎷α̅[t], reshape_size)
+    ⎷α̅ₜ₋₁ = reshape(scheduler.⎷α̅₋₁[t], reshape_size)
+    ⎷β̅ₜ = reshape(scheduler.⎷β̅[t], reshape_size)
 
     # compute predicted original sample x̂₀
    # arxiv:2006.11239 Eq. 15
     # arxiv:2208.11970 Eq. 115
-    x̂₀ = (xₜ - ⎷β̅ₜ' .* ϵᵧ) ./ ⎷α̅ₜ'
+    x̂₀ = (xₜ - ⎷β̅ₜ .* ϵᵧ) ./ ⎷α̅ₜ
 
     # compute predicted previous sample mean μ̃ₜ
     # arxiv:2006.11239 Eq. 7
     # arxiv:2208.11970 Eq. 84
     λ₀ = ⎷α̅ₜ₋₁ .* βₜ ./ β̅ₜ
     λₜ = ⎷αₜ .* β̅ₜ₋₁ ./ β̅ₜ # TODO: this could be stored in the scheduler
-    μ̃ₜ = λ₀' .* x̂₀ + λₜ' .* xₜ
+    μ̃ₜ = λ₀ .* x̂₀ + λₜ .* xₜ
 
     # sample predicted previous sample xₜ₋₁
     # arxiv:2006.11239 Eq. 6
     # arxiv:2208.11970 Eq. 70
-    σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ # TODO: this could be stored in the scheduler
-    xₜ₋₁ = μ̃ₜ + σₜ' .* randn(size(ϵᵧ))
+    σₜ = sqrt.(β̅ₜ₋₁ ./ β̅ₜ .* βₜ) # TODO: this could be stored in the scheduler
+    xₜ₋₁ = μ̃ₜ + σₜ .* randn(size(ϵᵧ))
 
     return xₜ₋₁, x̂₀
 end
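
Note on the reshape introduced in src/Schedulers/DDPM.jl: indexing a coefficient table at per-sample timesteps yields a length-B vector, and reshaping it to (1, ..., 1, B) lets it broadcast along the batch (last) dimension of the data. Unlike the previous adjoint trick (⎷α̅ₜ' .* x₀), which only works when the data is a 2-D matrix, this generalizes to arrays of any rank. A minimal, self-contained sketch of the pattern; the names coefs, x₀, and t below are illustrative stand-ins, not part of the package:

    # stand-in for a scheduler table such as scheduler.⎷α̅: one value per timestep
    coefs = rand(10)
    x₀ = randn(2, 5)      # 2 features, batch of 5 samples
    t = rand(1:10, 5)     # one timestep per sample

    # gather the per-sample coefficients and reshape them to (1, batch)
    # so that they broadcast along the last (batch) dimension only
    reshape_size = tuple(fill(1, ndims(x₀) - 1)..., size(t, 1))
    coefsₜ = reshape(coefs[t], reshape_size)

    y = coefsₜ .* x₀      # size (2, 5); the same code works for 4-D image batches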
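The prediction-collecting loop in examples/swissroll.jl follows the standard reverse-diffusion recipe: start from Gaussian noise and repeatedly apply the scheduler step from t = T down to t = 1. A hedged sketch of the same loop packaged as a helper, assuming the model, scheduler, and Diffusers.Schedulers.step from this repository are in scope; sample_reverse is a hypothetical name, not an existing API:

    function sample_reverse(model, scheduler, num_timesteps; n_samples=100)
        xₜ = randn(2, n_samples)     # start from pure Gaussian noise
        x̂₀ = xₜ
        for t in num_timesteps:-1:1
            ϵᵧ = model(xₜ, [t])      # predicted noise at timestep t
            xₜ, x̂₀ = Diffusers.Schedulers.step(scheduler, xₜ, ϵᵧ, [t])
        end
        return xₜ, x̂₀               # final sample and last x₀ prediction
    end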