mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-11-08 14:38:58 +00:00
🚚 rename a bunch of files
✨ add reverse process
This commit is contained in:
parent
03514a68d0
commit
48de7f0bce
|
@ -1,4 +1,5 @@
|
||||||
import Diffusers
|
import Diffusers
|
||||||
|
using Flux
|
||||||
using Random
|
using Random
|
||||||
using Plots
|
using Plots
|
||||||
|
|
||||||
|
@ -33,21 +34,21 @@ scatter(data[1, :], data[2, :],
|
||||||
aspectratio=:equal,
|
aspectratio=:equal,
|
||||||
)
|
)
|
||||||
|
|
||||||
num_timesteps = 1000
|
num_timesteps = 100
|
||||||
scheduler = Diffusers.DDPM(
|
scheduler = Diffusers.DDPM(
|
||||||
Vector{Float64},
|
Vector{Float64},
|
||||||
Diffusers.cosine_beta_schedule(num_timesteps, 0.999f0, 0.001f0),
|
Diffusers.cosine_beta_schedule(num_timesteps, 0.999f0, 0.001f0),
|
||||||
)
|
)
|
||||||
|
|
||||||
noise = randn(size(X))
|
noise = randn(size(data))
|
||||||
|
|
||||||
anim = @animate for i in 1:num_timesteps
|
anim = @animate for i in cat(collect(1:num_timesteps), repeat([num_timesteps], 50), dims=1)
|
||||||
noisy_data = Diffusers.add_noise(scheduler, X, noise, [i])
|
noisy_data = Diffusers.add_noise(scheduler, data, noise, [i])
|
||||||
scatter(noise[1, :], noise[2, :],
|
scatter(noise[1, :], noise[2, :],
|
||||||
alpha=0.1,
|
alpha=0.1,
|
||||||
aspectratio=:equal,
|
aspectratio=:equal,
|
||||||
label="noise",
|
label="noise",
|
||||||
legend=:topright,
|
legend=:outertopright,
|
||||||
)
|
)
|
||||||
scatter!(noisy_data[1, :], noisy_data[2, :],
|
scatter!(noisy_data[1, :], noisy_data[2, :],
|
||||||
alpha=0.5,
|
alpha=0.5,
|
||||||
|
@ -66,3 +67,53 @@ anim = @animate for i in 1:num_timesteps
|
||||||
end
|
end
|
||||||
|
|
||||||
gif(anim, "swissroll.gif", fps=50)
|
gif(anim, "swissroll.gif", fps=50)
|
||||||
|
|
||||||
|
d_hid = 32
|
||||||
|
model = Diffusers.ConditionalChain(
|
||||||
|
Parallel(
|
||||||
|
.+,
|
||||||
|
Dense(2, d_hid),
|
||||||
|
Chain(
|
||||||
|
Diffusers.SinusoidalPositionEmbedding(num_timesteps, d_hid),
|
||||||
|
Dense(d_hid, d_hid))
|
||||||
|
),
|
||||||
|
relu,
|
||||||
|
Parallel(
|
||||||
|
.+,
|
||||||
|
Dense(d_hid, d_hid),
|
||||||
|
Chain(
|
||||||
|
Diffusers.SinusoidalPositionEmbedding(num_timesteps, d_hid),
|
||||||
|
Dense(d_hid, d_hid))
|
||||||
|
),
|
||||||
|
relu,
|
||||||
|
Parallel(
|
||||||
|
.+,
|
||||||
|
Dense(d_hid, d_hid),
|
||||||
|
Chain(
|
||||||
|
Diffusers.SinusoidalPositionEmbedding(num_timesteps, d_hid),
|
||||||
|
Dense(d_hid, d_hid))
|
||||||
|
),
|
||||||
|
relu,
|
||||||
|
Dense(d_hid, 2),
|
||||||
|
)
|
||||||
|
|
||||||
|
model(data, [100])
|
||||||
|
|
||||||
|
|
||||||
|
num_epochs = 10
|
||||||
|
loss = Flux.Losses.mse
|
||||||
|
dataloader = Flux.DataLoader(X |> to_device; batchsize=32, shuffle=true);
|
||||||
|
for epoch = 1:num_epochs
|
||||||
|
progress = Progress(length(data); desc="epoch $epoch/$num_epochs")
|
||||||
|
params = Flux.params(model)
|
||||||
|
for data in dataloader
|
||||||
|
grads = Flux.gradient(model) do m
|
||||||
|
model_output = m(data)
|
||||||
|
noise_prediction = Diffusers.step(model_output, timesteps, scheduler)
|
||||||
|
loss(noise, noise_prediction)
|
||||||
|
end
|
||||||
|
Flux.update!(opt, params, grads)
|
||||||
|
ProgressMeter.next!(progress; showvalues=[("batch loss", @sprintf("%.5f", batch_loss))])
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
|
99
src/ConditionalChain.jl
Normal file
99
src/ConditionalChain.jl
Normal file
|
@ -0,0 +1,99 @@
|
||||||
|
using Flux
|
||||||
|
import Flux._show_children
|
||||||
|
import Flux._big_show
|
||||||
|
|
||||||
|
|
||||||
|
abstract type AbstractParallel end
|
||||||
|
|
||||||
|
# Dispatch helper used when running a `ConditionalChain`:
# layers that are `AbstractParallel` or `Flux.Parallel` receive the main input `x`
# together with every conditioning input in `ys`; any other layer only receives `x`.
function _maybe_forward(layer::AbstractParallel, x::AbstractArray, ys::AbstractArray...)
    return layer(x, ys...)
end

function _maybe_forward(layer::Parallel, x::AbstractArray, ys::AbstractArray...)
    return layer(x, ys...)
end

# Fallback: the conditioning inputs are accepted but ignored.
function _maybe_forward(layer, x::AbstractArray, ys::AbstractArray...)
    return layer(x)
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
ConditionalChain(layers...)
|
||||||
|
|
||||||
|
Based off `Flux.Chain` except takes in multiple inputs.
|
||||||
|
If a layer is of type `AbstractParallel` it uses all inputs else it uses only the first one.
|
||||||
|
The first input can therefore be conditioned on the other inputs.
|
||||||
|
"""
|
||||||
|
struct ConditionalChain{T<:Union{Tuple,NamedTuple}} <: AbstractParallel
    # Layers in application order; a `Tuple` (positional) or `NamedTuple` (named layers).
    layers::T
end
Flux.@functor ConditionalChain

# Positional form: `ConditionalChain(l1, l2, ...)` stores the layers as a tuple.
function ConditionalChain(xs...)
    return ConditionalChain(xs)
end

# Keyword form: `ConditionalChain(; name = layer, ...)` stores the layers as a
# `NamedTuple`. The name `layers` is rejected because it is the struct's field name.
# With no keywords at all, an empty tuple of layers is stored.
function ConditionalChain(; kw...)
    if :layers in keys(kw)
        throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
    end
    if isempty(kw)
        return ConditionalChain(())
    end
    return ConditionalChain(values(kw))
end
|
||||||
|
|
||||||
|
# Delegate the basic collection interface to the wrapped `layers` container.
Flux.@forward ConditionalChain.layers Base.length, Base.iterate, Base.keys,
    Base.getindex, Base.first, Base.last, Base.firstindex, Base.lastindex

# Indexing with an array of indices builds a new chain from the selected layers
# instead of returning the raw selection.
function Base.getindex(c::ConditionalChain, i::AbstractArray)
    selected = c.layers[i]
    return ConditionalChain(selected...)
end
|
||||||
|
|
||||||
|
# Run the chain: thread `x` through every layer in order. The extra inputs `ys`
# are offered unchanged to each layer through `_maybe_forward`, which decides
# whether the layer receives them. An empty chain returns `x` untouched.
function (c::ConditionalChain)(x, ys...)
    h = x
    for layer in c.layers
        h = _maybe_forward(layer, h, ys...)
    end
    return h
end
|
||||||
|
|
||||||
|
# Compact one-line representation: `ConditionalChain(<layers>)`, with the layer
# list rendered by Flux's own `_show_layers` helper.
function Base.show(io::IO, c::ConditionalChain)
    print(io, "ConditionalChain(")
    Flux._show_layers(io, c.layers)
    return print(io, ")")
end
|
||||||
|
|
||||||
|
# Multi-line pretty printer for a `ConditionalChain` whose layers are named
# (stored in a `NamedTuple`). Extends `Flux._big_show`.
#
# - `io`: stream to print to.
# - `m`: the chain to display.
# - `indent`: current indentation, in spaces; `0` means this is the outermost object.
# - `name`: name of this chain inside its parent, or `nothing` when it has none.
#
# Every child layer is printed, indented by two more spaces and prefixed with its key.
function _big_show(io::IO, m::ConditionalChain{T}, indent::Int=0, name=nothing) where {T<:NamedTuple}
    println(io, " "^indent, isnothing(name) ? "" : "$name = ", "ConditionalChain(")
    for k in Base.keys(m.layers)
        _big_show(io, m.layers[k], indent + 2, k)
    end
    if indent == 0
        print(io, ") ")
        # Only `_show_children` and `_big_show` are imported from Flux in this file,
        # so the closing-summary helper has to be called with its module prefix.
        Flux._big_finale(io, m)
    else
        # Nested inside another container: close the bracket and leave a trailing comma.
        println(io, " "^indent, ")", ",")
    end
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
ConditionalSkipConnection(layers, connection)
|
||||||
|
|
||||||
|
The output is equivalent to `connection(layers(x, ys...), x)`.
|
||||||
|
Based off Flux.SkipConnection except it passes multiple arguments to layers.
|
||||||
|
"""
|
||||||
|
struct ConditionalSkipConnection{T,F} <: AbstractParallel
    # Wrapped layer(s); called with the main input and all conditioning inputs.
    layers::T
    # Binary function combining the output of `layers` with the original input.
    connection::F
end

Flux.@functor ConditionalSkipConnection

# Forward pass: evaluate the wrapped layers on `(x, ys...)`, then merge the result
# with the untouched input `x` through `connection`.
function (skip::ConditionalSkipConnection)(x, ys...)
    inner = skip.layers(x, ys...)
    return skip.connection(inner, x)
end

# Compact one-line representation: `ConditionalSkipConnection(<layers>, <connection>)`.
function Base.show(io::IO, b::ConditionalSkipConnection)
    print(io, "ConditionalSkipConnection(")
    print(io, b.layers, ", ", b.connection)
    return print(io, ")")
end
|
||||||
|
|
||||||
|
### Show. Copied from Flux.jl/src/layers/show.jl

# Define the `text/plain` display for each container type declared in this file.
# The list holds type *names* as Symbols (the original mixed a Symbol with a bare
# type object); each one is spliced into the method signature by `@eval`.
for T in [:ConditionalChain, :ConditionalSkipConnection]
    @eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
        if get(io, :typeinfo, nothing) === nothing  # e.g. top level in REPL
            Flux._big_show(io, x)
        elseif !get(io, :compact, false)  # e.g. printed inside a Vector, but not a Matrix
            Flux._layer_show(io, x)
        else
            show(io, x)
        end
    end
end

# Children shown by Flux's multi-line printer for a `ConditionalChain`: its layers.
_show_children(c::ConditionalChain) = c.layers
|
|
@ -1,4 +1,4 @@
|
||||||
include("scheduler.jl")
|
include("Schedulers.jl")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Denoising Diffusion Probabilistic Models (DDPM) scheduler.
|
Denoising Diffusion Probabilistic Models (DDPM) scheduler.
|
|
@ -1,7 +1,14 @@
|
||||||
module Diffusers
|
module Diffusers
|
||||||
|
|
||||||
include("scheduler.jl")
|
# utils
|
||||||
include("beta_scheduler.jl")
|
include("Embeddings.jl")
|
||||||
include("ddpm.jl")
|
include("ConditionalChain.jl")
|
||||||
|
|
||||||
|
# abstract types
|
||||||
|
include("Schedulers.jl")
|
||||||
|
include("BetaSchedulers.jl")
|
||||||
|
|
||||||
|
# concrete types
|
||||||
|
include("DDPM.jl")
|
||||||
|
|
||||||
end # module Diffusers
|
end # module Diffusers
|
||||||
|
|
33
src/Embeddings.jl
Normal file
33
src/Embeddings.jl
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
using Flux
|
||||||
|
|
||||||
|
# Lookup table of sinusoidal position encodings. Calling the layer with integer
# positions returns the matching columns of `weight`.
struct SinusoidalPositionEmbedding{W<:AbstractArray}
    weight::W
end

Flux.@functor SinusoidalPositionEmbedding
# The table is fixed: report no trainable parameters to Flux.
Flux.trainable(emb::SinusoidalPositionEmbedding) = ()

# Convenience constructor: build the table for `in` positions, each encoded as a
# vector of length `out` (the table therefore has `out` rows and `in` columns).
SinusoidalPositionEmbedding(in::Int, out::Int) =
    SinusoidalPositionEmbedding(make_positional_embedding(out, in))
|
||||||
|
|
||||||
|
"""
    make_positional_embedding(dim_embedding::Int, seq_length::Int=1000; n::Int=10000)

Build a `dim_embedding × seq_length` `Float32` matrix of sinusoidal position encodings.

For every position `pos` in `1:seq_length` and every even zero-based row offset `row`,
row `row + 1` holds `sin(pos * ω)` and row `row + 2` holds `cos(pos * ω)`, where
`ω = 1 / n^(row / (dim_embedding - 2))`.

# Arguments
- `dim_embedding`: number of rows (length of one encoding); must be even because rows
  are filled in sin/cos pairs.
- `seq_length`: number of positions (columns).

# Keywords
- `n`: base of the geometric frequency progression.

# Throws
- `ArgumentError` if `dim_embedding` is odd.
"""
function make_positional_embedding(dim_embedding::Int, seq_length::Int=1000; n::Int=10000)
    # Rows are written two at a time; an odd size would write past the last row.
    iseven(dim_embedding) || throw(ArgumentError(
        "dim_embedding must be even, got $dim_embedding"))

    embedding = Matrix{Float32}(undef, dim_embedding, seq_length)
    # With a single sin/cos pair the exponent `row / (dim_embedding - 2)` is 0/0 (NaN).
    # Clamp the divisor so that this case uses exponent 0, i.e. frequency 1.
    # Sizes of 4 and above are unaffected.
    scale = max(dim_embedding - 2, 1)
    for pos in 1:seq_length
        for row in 0:2:(dim_embedding - 1)
            denom = 1.0 / (n^(row / scale))
            embedding[row + 1, pos] = sin(pos * denom)
            embedding[row + 2, pos] = cos(pos * denom)
        end
    end
    return embedding
end
|
||||||
|
|
||||||
|
# Single position: return the corresponding column of the table.
function (m::SinusoidalPositionEmbedding)(x::Integer)
    return m.weight[:, x]
end

# Vector of positions: gather the corresponding columns of the table.
function (m::SinusoidalPositionEmbedding)(x::AbstractVector)
    return NNlib.gather(m.weight, x)
end

# Array of positions of any other shape: look up the flattened positions, then
# reshape so the encoding dimension comes first, followed by the dimensions of `x`.
function (m::SinusoidalPositionEmbedding)(x::AbstractArray)
    flat = m(vec(x))
    return reshape(flat, :, size(x)...)
end
|
||||||
|
|
||||||
|
# Compact representation: `SinusoidalPositionEmbedding(<columns> => <rows>)`,
# taken from the second and first dimension of the weight table respectively.
function Base.show(io::IO, m::SinusoidalPositionEmbedding)
    n_in = size(m.weight, 2)
    n_out = size(m.weight, 1)
    return print(io, "SinusoidalPositionEmbedding(", n_in, " => ", n_out, ")")
end
|
72
src/Schedulers.jl
Normal file
72
src/Schedulers.jl
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
abstract type Scheduler end
|
||||||
|
|
||||||
|
"""
    add_noise(scheduler::Scheduler, clean_data, noise, timesteps)

Add noise to clean data using the forward diffusion process.

# Arguments
- `scheduler::Scheduler`: scheduler object; its `sqrt_α_cumprods` and
  `sqrt_one_minus_α_cumprods` fields are indexed with `timesteps`.
- `clean_data::AbstractArray`: clean data to add noise to.
- `noise::AbstractArray`: noise to add to the clean data.
- `timesteps::AbstractArray`: timesteps used to weight the noise. The selected
  coefficients are broadcast against `clean_data` and `noise`, so the shape of
  `timesteps` must be broadcast-compatible with them.

# Returns
- `AbstractArray`: noisy data at the given timesteps.
"""
function add_noise(
    scheduler::Scheduler,
    clean_data::AbstractArray,
    noise::AbstractArray,
    timesteps::AbstractArray,
)
    sqrt_α_cumprod_t = scheduler.sqrt_α_cumprods[timesteps]
    sqrt_one_minus_α_cumprod_t = scheduler.sqrt_one_minus_α_cumprods[timesteps]

    return sqrt_α_cumprod_t .* clean_data .+ sqrt_one_minus_α_cumprod_t .* noise
end
|
||||||
|
|
||||||
|
"""
    step(scheduler::Scheduler, sample, model_output, timestep::Int)

Remove noise from `sample` using one step of the backward diffusion process.

# Arguments
- `scheduler::Scheduler`: scheduler object; its `α_cumprods` and `βs` fields are
  indexed with `timestep`.
- `sample::AbstractArray`: sample to remove noise from, i.e. the model input.
- `model_output::AbstractArray`: predicted noise from the model.
- `timestep::Int`: timestep to remove noise from (1-based).

# Returns
- `(pred_prev_sample, pred_original_sample)`: the sample at the previous timestep
  and the predicted fully denoised sample.
"""
function step(
    scheduler::Scheduler,
    sample::AbstractArray,
    model_output::AbstractArray,
    timestep::Int,
)
    # 1. compute alphas, betas
    α_cumprod_t = scheduler.α_cumprods[timestep]
    # There is no table entry before the first timestep; the cumulative product of
    # zero factors is 1 (no noise has been added yet).
    α_cumprod_t_prev = timestep > 1 ? scheduler.α_cumprods[timestep - 1] : one(α_cumprod_t)
    β_cumprod_t = 1 - α_cumprod_t
    β_cumprod_t_prev = 1 - α_cumprod_t_prev
    current_α_t = α_cumprod_t / α_cumprod_t_prev
    current_β_t = 1 - current_α_t

    # 2. compute predicted original sample from predicted noise, also called
    # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
    # (epsilon prediction type). This uses the current sample and the cumulative
    # products at timestep t.
    pred_original_sample = (sample .- sqrt(β_cumprod_t) .* model_output) ./ sqrt(α_cumprod_t)

    # 3. compute coefficients for pred_original_sample x_0 and current sample x_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_original_sample_coeff = (sqrt(α_cumprod_t_prev) * current_β_t) / β_cumprod_t
    current_sample_coeff = sqrt(current_α_t) * β_cumprod_t_prev / β_cumprod_t

    # 4. compute predicted previous sample µ_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_prev_sample =
        pred_original_sample_coeff .* pred_original_sample .+ current_sample_coeff .* sample

    # 5. add noise, except on the very last denoising step, whose output is the
    # final sample (Algorithm 2 of the paper uses z = 0 there).
    if timestep > 1
        scaled_noise = sqrt(scheduler.βs[timestep]) .* randn(size(model_output))
        pred_prev_sample = pred_prev_sample .+ scaled_noise
    end

    return pred_prev_sample, pred_original_sample
end
|
|
@ -1,14 +0,0 @@
|
||||||
abstract type Scheduler end
|
|
||||||
|
|
||||||
function add_noise(
|
|
||||||
scheduler::Scheduler,
|
|
||||||
original_samples::AbstractArray,
|
|
||||||
noise::AbstractArray,
|
|
||||||
timesteps::AbstractArray,
|
|
||||||
)
|
|
||||||
alphas_cumprod = scheduler.α_cumprods[timesteps]
|
|
||||||
sqrt_alpha_prod = sqrt.(alphas_cumprod)
|
|
||||||
sqrt_one_minus_alpha_prod = sqrt.(1 .- alphas_cumprod)
|
|
||||||
|
|
||||||
sqrt_alpha_prod .* original_samples .+ sqrt_one_minus_alpha_prod .* noise
|
|
||||||
end
|
|
Loading…
Reference in a new issue