🚚 rename a bunch of files

 add reverse process
This commit is contained in:
Laureηt 2023-07-05 21:13:26 +02:00
parent 03514a68d0
commit 48de7f0bce
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
8 changed files with 271 additions and 23 deletions

View file

@ -1,4 +1,5 @@
import Diffusers
using Flux
using Random
using Plots
@ -33,21 +34,21 @@ scatter(data[1, :], data[2, :],
aspectratio=:equal,
)
num_timesteps = 1000
num_timesteps = 100
scheduler = Diffusers.DDPM(
Vector{Float64},
Diffusers.cosine_beta_schedule(num_timesteps, 0.999f0, 0.001f0),
)
noise = randn(size(X))
noise = randn(size(data))
anim = @animate for i in 1:num_timesteps
noisy_data = Diffusers.add_noise(scheduler, X, noise, [i])
anim = @animate for i in cat(collect(1:num_timesteps), repeat([num_timesteps], 50), dims=1)
noisy_data = Diffusers.add_noise(scheduler, data, noise, [i])
scatter(noise[1, :], noise[2, :],
alpha=0.1,
aspectratio=:equal,
label="noise",
legend=:topright,
legend=:outertopright,
)
scatter!(noisy_data[1, :], noisy_data[2, :],
alpha=0.5,
@ -66,3 +67,53 @@ anim = @animate for i in 1:num_timesteps
end
gif(anim, "swissroll.gif", fps=50)
d_hid = 32
model = Diffusers.ConditionalChain(
Parallel(
.+,
Dense(2, d_hid),
Chain(
Diffusers.SinusoidalPositionEmbedding(num_timesteps, d_hid),
Dense(d_hid, d_hid))
),
relu,
Parallel(
.+,
Dense(d_hid, d_hid),
Chain(
Diffusers.SinusoidalPositionEmbedding(num_timesteps, d_hid),
Dense(d_hid, d_hid))
),
relu,
Parallel(
.+,
Dense(d_hid, d_hid),
Chain(
Diffusers.SinusoidalPositionEmbedding(num_timesteps, d_hid),
Dense(d_hid, d_hid))
),
relu,
Dense(d_hid, 2),
)
model(data, [100])
num_epochs = 10
loss = Flux.Losses.mse
dataloader = Flux.DataLoader(X |> to_device; batchsize=32, shuffle=true);
for epoch = 1:num_epochs
progress = Progress(length(data); desc="epoch $epoch/$num_epochs")
params = Flux.params(model)
for data in dataloader
grads = Flux.gradient(model) do m
model_output = m(data)
noise_prediction = Diffusers.step(model_output, timesteps, scheduler)
loss(noise, noise_prediction)
end
Flux.update!(opt, params, grads)
ProgressMeter.next!(progress; showvalues=[("batch loss", @sprintf("%.5f", batch_loss))])
end
end

99
src/ConditionalChain.jl Normal file
View file

@ -0,0 +1,99 @@
using Flux
import Flux._show_children
import Flux._big_show
abstract type AbstractParallel end
_maybe_forward(layer::AbstractParallel, x::AbstractArray, ys::AbstractArray...) = layer(x, ys...)
_maybe_forward(layer::Parallel, x::AbstractArray, ys::AbstractArray...) = layer(x, ys...)
_maybe_forward(layer, x::AbstractArray, ys::AbstractArray...) = layer(x)
"""
ConditionalChain(layers...)
Based off `Flux.Chain` except takes in multiple inputs.
If a layer is of type `AbstractParallel` it uses all inputs else it uses only the first one.
The first input can therefore be conditioned on the other inputs.
"""
struct ConditionalChain{T<:Union{Tuple,NamedTuple}} <: AbstractParallel
layers::T
end
Flux.@functor ConditionalChain
ConditionalChain(xs...) = ConditionalChain(xs)
function ConditionalChain(; kw...)
:layers in keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
isempty(kw) && return ConditionalChain(())
ConditionalChain(values(kw))
end
Flux.@forward ConditionalChain.layers Base.getindex, Base.length, Base.first, Base.last,
Base.iterate, Base.lastindex, Base.keys, Base.firstindex
Base.getindex(c::ConditionalChain, i::AbstractArray) = ConditionalChain(c.layers[i]...)
function (c::ConditionalChain)(x, ys...)
for layer in c.layers
x = _maybe_forward(layer, x, ys...)
end
x
end
function Base.show(io::IO, c::ConditionalChain)
print(io, "ConditionalChain(")
Flux._show_layers(io, c.layers)
print(io, ")")
end
function _big_show(io::IO, m::ConditionalChain{T}, indent::Int=0, name=nothing) where {T<:NamedTuple}
println(io, " "^indent, isnothing(name) ? "" : "$name = ", "ConditionalChain(")
for k in Base.keys(m.layers)
_big_show(io, m.layers[k], indent + 2, k)
end
if indent == 0
print(io, ") ")
_big_finale(io, m)
else
println(io, " "^indent, ")", ",")
end
end
"""
ConditionalSkipConnection(layers, connection)
The output is equivalent to `connection(layers(x, ys...), x)`.
Based off Flux.SkipConnection except it passes multiple arguments to layers.
"""
struct ConditionalSkipConnection{T,F} <: AbstractParallel
layers::T
connection::F
end
Flux.@functor ConditionalSkipConnection
function (skip::ConditionalSkipConnection)(x, ys...)
skip.connection(skip.layers(x, ys...), x)
end
function Base.show(io::IO, b::ConditionalSkipConnection)
print(io, "ConditionalSkipConnection(", b.layers, ", ", b.connection, ")")
end
### Show. Copied from Flux.jl/src/layers/show.jl
for T in [
:ConditionalChain, ConditionalSkipConnection
]
@eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL
Flux._big_show(io, x)
elseif !get(io, :compact, false) # e.g. printed inside a Vector, but not a Matrix
Flux._layer_show(io, x)
else
show(io, x)
end
end
end
_show_children(c::ConditionalChain) = c.layers

View file

@ -1,4 +1,4 @@
include("scheduler.jl")
include("Schedulers.jl")
"""
Denoising Diffusion Probabilistic Models (DDPM) scheduler.

View file

@ -1,7 +1,14 @@
module Diffusers
include("scheduler.jl")
include("beta_scheduler.jl")
include("ddpm.jl")
# utils
include("Embeddings.jl")
include("ConditionalChain.jl")
# abtract types
include("Schedulers.jl")
include("BetaSchedulers.jl")
# concrete types
include("DDPM.jl")
end # module Diffusers

33
src/Embeddings.jl Normal file
View file

@ -0,0 +1,33 @@
using Flux
struct SinusoidalPositionEmbedding{W<:AbstractArray}
weight::W
end
Flux.@functor SinusoidalPositionEmbedding
Flux.trainable(emb::SinusoidalPositionEmbedding) = () # mark it as an non-trainable array
function SinusoidalPositionEmbedding(in::Int, out::Int)
W = make_positional_embedding(out, in)
SinusoidalPositionEmbedding(W)
end
function make_positional_embedding(dim_embedding::Int, seq_length::Int=1000; n::Int=10000)
embedding = Matrix{Float32}(undef, dim_embedding, seq_length)
for pos in 1:seq_length
for row in 0:2:(dim_embedding-1)
denom = 1.0 / (n^(row / (dim_embedding - 2)))
embedding[row+1, pos] = sin(pos * denom)
embedding[row+2, pos] = cos(pos * denom)
end
end
embedding
end
(m::SinusoidalPositionEmbedding)(x::Integer) = m.weight[:, x]
(m::SinusoidalPositionEmbedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
(m::SinusoidalPositionEmbedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
function Base.show(io::IO, m::SinusoidalPositionEmbedding)
print(io, "SinusoidalPositionEmbedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
end

72
src/Schedulers.jl Normal file
View file

@ -0,0 +1,72 @@
abstract type Scheduler end
function add_noise(
scheduler::Scheduler,
clean_data::AbstractArray,
noise::AbstractArray,
timesteps::AbstractArray,
)
"""
Add noise to clean data using the forward diffusion process.
Args:
scheduler (`Scheduler`): scheduler object.
clean_data (`AbstractArray`): clean data to add noise to.
noise (`AbstractArray`): noise to add to clean data.
timesteps (`AbstractArray`): timesteps used to weight the noise.
Returns:
`AbstractArray`: noisy data at the given timesteps.
"""
sqrt_α_cumprod_t = scheduler.sqrt_α_cumprods[timesteps]
sqrt_one_minus_α_cumprod_t = scheduler.sqrt_one_minus_α_cumprods[timesteps]
sqrt_α_cumprod_t .* clean_data .+ sqrt_one_minus_α_cumprod_t .* noise
end
function step(
scheduler::Scheduler,
sample::AbstractArray,
model_output::AbstractArray,
timestep::Int,
)
"""
Remove noise from model output using the backward diffusion process.
Args:
scheduler (`Scheduler`): scheduler object.
sample (`AbstractArray`): sample to remove noise from, i.e. model_input.
model_output (`AbstractArray`): predicted noise from the model.
timestep (`Int`): timestep to remove noise from.
Returns:
`AbstractArray`: denoised model output at the given timestep.
"""
# 1. compute alphas, betas
α_cumprod_t = scheduler.α_cumprods[timestep]
α_cumprod_t_prev = scheduler.α_cumprods[timestep - 1]
β_cumprod_t = 1 - α_cumprod_t
β_cumprod_t_prev = 1 - α_cumprod_t_prev
current_α_t = α_cumprod_t / α_cumprod_t_prev
current_β_t = 1 - current_α_t
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
# epsilon prediction type
x_0 = (noise - β_cumprod_t_prev * model_output) / α_cumprod_t_prev
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_sample_coeff = (α_cumprod_t_prev * current_β_t) / β_cumprod_t
current_sample_coeff = current_α_t * β_cumprod_t_prev / β_cumprod_t
# 5. Compute predicted previous sample µ_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
# 6. Add noise
variance = scheduler.βs[timestep] * randn(size(model_output))
pred_prev_sample = pred_prev_sample + variance
return pred_prev_sample, pred_original_sample

View file

@ -1,14 +0,0 @@
abstract type Scheduler end
function add_noise(
scheduler::Scheduler,
original_samples::AbstractArray,
noise::AbstractArray,
timesteps::AbstractArray,
)
alphas_cumprod = scheduler.α_cumprods[timesteps]
sqrt_alpha_prod = sqrt.(alphas_cumprod)
sqrt_one_minus_alpha_prod = sqrt.(1 .- alphas_cumprod)
sqrt_alpha_prod .* original_samples .+ sqrt_one_minus_alpha_prod .* noise
end