mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-09-18 18:45:28 +00:00
🚚 rename a bunch of files
✨ add reverse process
This commit is contained in:
parent
03514a68d0
commit
48de7f0bce
|
@ -1,4 +1,5 @@
|
|||
import Diffusers
|
||||
using Flux
|
||||
using Random
|
||||
using Plots
|
||||
|
||||
|
@ -33,21 +34,21 @@ scatter(data[1, :], data[2, :],
|
|||
aspectratio=:equal,
|
||||
)
|
||||
|
||||
num_timesteps = 1000
|
||||
num_timesteps = 100
|
||||
scheduler = Diffusers.DDPM(
|
||||
Vector{Float64},
|
||||
Diffusers.cosine_beta_schedule(num_timesteps, 0.999f0, 0.001f0),
|
||||
)
|
||||
|
||||
noise = randn(size(X))
|
||||
noise = randn(size(data))
|
||||
|
||||
anim = @animate for i in 1:num_timesteps
|
||||
noisy_data = Diffusers.add_noise(scheduler, X, noise, [i])
|
||||
anim = @animate for i in cat(collect(1:num_timesteps), repeat([num_timesteps], 50), dims=1)
|
||||
noisy_data = Diffusers.add_noise(scheduler, data, noise, [i])
|
||||
scatter(noise[1, :], noise[2, :],
|
||||
alpha=0.1,
|
||||
aspectratio=:equal,
|
||||
label="noise",
|
||||
legend=:topright,
|
||||
legend=:outertopright,
|
||||
)
|
||||
scatter!(noisy_data[1, :], noisy_data[2, :],
|
||||
alpha=0.5,
|
||||
|
@ -66,3 +67,53 @@ anim = @animate for i in 1:num_timesteps
|
|||
end
|
||||
|
||||
gif(anim, "swissroll.gif", fps=50)
|
||||
|
||||
d_hid = 32
|
||||
model = Diffusers.ConditionalChain(
|
||||
Parallel(
|
||||
.+,
|
||||
Dense(2, d_hid),
|
||||
Chain(
|
||||
Diffusers.SinusoidalPositionEmbedding(num_timesteps, d_hid),
|
||||
Dense(d_hid, d_hid))
|
||||
),
|
||||
relu,
|
||||
Parallel(
|
||||
.+,
|
||||
Dense(d_hid, d_hid),
|
||||
Chain(
|
||||
Diffusers.SinusoidalPositionEmbedding(num_timesteps, d_hid),
|
||||
Dense(d_hid, d_hid))
|
||||
),
|
||||
relu,
|
||||
Parallel(
|
||||
.+,
|
||||
Dense(d_hid, d_hid),
|
||||
Chain(
|
||||
Diffusers.SinusoidalPositionEmbedding(num_timesteps, d_hid),
|
||||
Dense(d_hid, d_hid))
|
||||
),
|
||||
relu,
|
||||
Dense(d_hid, 2),
|
||||
)
|
||||
|
||||
model(data, [100])
|
||||
|
||||
|
||||
num_epochs = 10
|
||||
loss = Flux.Losses.mse
|
||||
dataloader = Flux.DataLoader(X |> to_device; batchsize=32, shuffle=true);
|
||||
for epoch = 1:num_epochs
|
||||
progress = Progress(length(data); desc="epoch $epoch/$num_epochs")
|
||||
params = Flux.params(model)
|
||||
for data in dataloader
|
||||
grads = Flux.gradient(model) do m
|
||||
model_output = m(data)
|
||||
noise_prediction = Diffusers.step(model_output, timesteps, scheduler)
|
||||
loss(noise, noise_prediction)
|
||||
end
|
||||
Flux.update!(opt, params, grads)
|
||||
ProgressMeter.next!(progress; showvalues=[("batch loss", @sprintf("%.5f", batch_loss))])
|
||||
end
|
||||
end
|
||||
|
||||
|
|
99
src/ConditionalChain.jl
Normal file
99
src/ConditionalChain.jl
Normal file
|
@ -0,0 +1,99 @@
|
|||
using Flux
|
||||
import Flux._show_children
|
||||
import Flux._big_show
|
||||
|
||||
|
||||
abstract type AbstractParallel end
|
||||
|
||||
# Dispatch helper used when running a chain of layers on `x` plus extra inputs `ys`:
# layers that are `AbstractParallel` or `Parallel` are called with every input,
# any other layer is called with `x` only and the extra inputs are ignored.
_maybe_forward(layer::AbstractParallel, x::AbstractArray, ys::AbstractArray...) = layer(x, ys...)
|
||||
_maybe_forward(layer::Parallel, x::AbstractArray, ys::AbstractArray...) = layer(x, ys...)
|
||||
_maybe_forward(layer, x::AbstractArray, ys::AbstractArray...) = layer(x)
|
||||
|
||||
"""
|
||||
ConditionalChain(layers...)
|
||||
|
||||
Based off `Flux.Chain` except takes in multiple inputs.
|
||||
If a layer is of type `AbstractParallel` it uses all inputs else it uses only the first one.
|
||||
The first input can therefore be conditioned on the other inputs.
|
||||
"""
|
||||
struct ConditionalChain{T<:Union{Tuple,NamedTuple}} <: AbstractParallel
|
||||
layers::T
|
||||
end
|
||||
Flux.@functor ConditionalChain
|
||||
|
||||
ConditionalChain(xs...) = ConditionalChain(xs)
|
||||
"""
    ConditionalChain(; kw...)

Build a `ConditionalChain` whose layers are named by the keyword arguments.
With no keyword arguments an empty chain is returned.

# Throws
- `ArgumentError`: if one of the keywords is named `layers`.
"""
function ConditionalChain(; kw...)
    if :layers in keys(kw)
        throw(ArgumentError(
            "a ConditionalChain cannot have a named layer called `layers`"
        ))
    end
    isempty(kw) && return ConditionalChain(())
    return ConditionalChain(values(kw))
end
|
||||
|
||||
Flux.@forward ConditionalChain.layers Base.getindex, Base.length, Base.first, Base.last,
|
||||
Base.iterate, Base.lastindex, Base.keys, Base.firstindex
|
||||
|
||||
Base.getindex(c::ConditionalChain, i::AbstractArray) = ConditionalChain(c.layers[i]...)
|
||||
|
||||
# Forward pass: apply the layers in order, feeding each layer's output to the next
# one as `x`. The extra inputs `ys` are passed unchanged to `_maybe_forward` for
# every layer. Returns the output of the last layer (or `x` itself if there are none).
function (c::ConditionalChain)(x, ys...)
|
||||
for layer in c.layers
|
||||
x = _maybe_forward(layer, x, ys...)
|
||||
end
|
||||
x
|
||||
end
|
||||
|
||||
function Base.show(io::IO, c::ConditionalChain)
|
||||
print(io, "ConditionalChain(")
|
||||
Flux._show_layers(io, c.layers)
|
||||
print(io, ")")
|
||||
end
|
||||
|
||||
# Multi-line display of a `ConditionalChain` built from named layers.
# Each layer is printed through `_big_show`, labelled with its key and indented
# two columns deeper than this container.
function _big_show(
    io::IO, m::ConditionalChain{T}, indent::Int=0, name=nothing
) where {T<:NamedTuple}
    println(io, " "^indent, isnothing(name) ? "" : "$name = ", "ConditionalChain(")
    for k in Base.keys(m.layers)
        _big_show(io, m.layers[k], indent + 2, k)
    end
    if indent == 0
        # Outermost container: close the parenthesis and let Flux print its footer.
        # Only `_big_show` and `_show_children` are imported from Flux in this file,
        # so `_big_finale` has to be module-qualified to be resolvable here.
        print(io, ") ")
        Flux._big_finale(io, m)
    else
        # Nested container: close the parenthesis and leave a trailing comma.
        println(io, " "^indent, ")", ",")
    end
end
|
||||
|
||||
"""
|
||||
ConditionalSkipConnection(layers, connection)
|
||||
|
||||
The output is equivalent to `connection(layers(x, ys...), x)`.
|
||||
Based off Flux.SkipConnection except it passes multiple arguments to layers.
|
||||
"""
|
||||
struct ConditionalSkipConnection{T,F} <: AbstractParallel
|
||||
layers::T
|
||||
connection::F
|
||||
end
|
||||
|
||||
Flux.@functor ConditionalSkipConnection
|
||||
|
||||
function (skip::ConditionalSkipConnection)(x, ys...)
|
||||
skip.connection(skip.layers(x, ys...), x)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, b::ConditionalSkipConnection)
|
||||
print(io, "ConditionalSkipConnection(", b.layers, ", ", b.connection, ")")
|
||||
end
|
||||
|
||||
### Show. Copied from Flux.jl/src/layers/show.jl
|
||||
|
||||
# Define the `text/plain` `show` method for each container type of this file.
# Every entry is a `Symbol` so `$T` is spliced into the signature uniformly.
for T in [
    :ConditionalChain, :ConditionalSkipConnection
]
    @eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
        if get(io, :typeinfo, nothing) === nothing  # e.g. top level in REPL
            Flux._big_show(io, x)
        elseif !get(io, :compact, false)  # e.g. printed inside a Vector, but not a Matrix
            Flux._layer_show(io, x)
        else
            show(io, x)
        end
    end
end
|
||||
|
||||
_show_children(c::ConditionalChain) = c.layers
|
|
@ -1,4 +1,4 @@
|
|||
include("scheduler.jl")
|
||||
include("Schedulers.jl")
|
||||
|
||||
"""
|
||||
Denoising Diffusion Probabilistic Models (DDPM) scheduler.
|
|
@ -1,7 +1,14 @@
|
|||
module Diffusers
|
||||
|
||||
include("scheduler.jl")
|
||||
include("beta_scheduler.jl")
|
||||
include("ddpm.jl")
|
||||
# utils
|
||||
include("Embeddings.jl")
|
||||
include("ConditionalChain.jl")
|
||||
|
||||
# abstract types
|
||||
include("Schedulers.jl")
|
||||
include("BetaSchedulers.jl")
|
||||
|
||||
# concrete types
|
||||
include("DDPM.jl")
|
||||
|
||||
end # module Diffusers
|
||||
|
|
33
src/Embeddings.jl
Normal file
33
src/Embeddings.jl
Normal file
|
@ -0,0 +1,33 @@
|
|||
using Flux
|
||||
|
||||
struct SinusoidalPositionEmbedding{W<:AbstractArray}
|
||||
weight::W
|
||||
end
|
||||
|
||||
Flux.@functor SinusoidalPositionEmbedding
|
||||
Flux.trainable(emb::SinusoidalPositionEmbedding) = () # mark it as a non-trainable array
|
||||
|
||||
function SinusoidalPositionEmbedding(in::Int, out::Int)
|
||||
W = make_positional_embedding(out, in)
|
||||
SinusoidalPositionEmbedding(W)
|
||||
end
|
||||
|
||||
"""
    make_positional_embedding(dim_embedding::Int, seq_length::Int=1000; n::Int=10000)

Build a `dim_embedding × seq_length` `Float32` matrix of sinusoidal position
embeddings. Column `pos` holds the embedding of position `pos` (1-based). Rows are
filled in (sine, cosine) pairs that share one frequency.

The pair starting at 0-based row `row` uses the frequency
`1 / n^(row / (dim_embedding - 2))`. When there is a single pair
(`dim_embedding == 2`) the frequency is `1`.

# Arguments
- `dim_embedding`: number of rows; must be even because rows are written in pairs.
- `seq_length`: number of positions (columns).

# Keywords
- `n`: base of the frequency progression.

# Throws
- `ArgumentError`: if `dim_embedding` is odd.
"""
function make_positional_embedding(dim_embedding::Int, seq_length::Int=1000; n::Int=10000)
    # An odd size would make the last pair write one row past the end of the matrix.
    iseven(dim_embedding) || throw(ArgumentError(
        "dim_embedding must be even, got $dim_embedding"
    ))
    embedding = Matrix{Float32}(undef, dim_embedding, seq_length)
    # With a single pair `dim_embedding - 2` is zero and the exponent would be
    # 0 / 0 = NaN; clamping the divisor leaves every larger size unchanged.
    divisor = max(dim_embedding - 2, 1)
    for pos in 1:seq_length
        for row in 0:2:(dim_embedding - 1)
            denom = 1.0 / (n^(row / divisor))
            embedding[row + 1, pos] = sin(pos * denom)
            embedding[row + 2, pos] = cos(pos * denom)
        end
    end
    return embedding
end
|
||||
|
||||
(m::SinusoidalPositionEmbedding)(x::Integer) = m.weight[:, x]
|
||||
(m::SinusoidalPositionEmbedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
|
||||
(m::SinusoidalPositionEmbedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
|
||||
|
||||
function Base.show(io::IO, m::SinusoidalPositionEmbedding)
|
||||
print(io, "SinusoidalPositionEmbedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
|
||||
end
|
72
src/Schedulers.jl
Normal file
72
src/Schedulers.jl
Normal file
|
@ -0,0 +1,72 @@
|
|||
abstract type Scheduler end
|
||||
|
||||
"""
    add_noise(scheduler::Scheduler, clean_data, noise, timesteps)

Add noise to clean data using the forward diffusion process:

    sqrt_α_cumprods[t] .* clean_data .+ sqrt_one_minus_α_cumprods[t] .* noise

# Arguments
- `scheduler::Scheduler`: scheduler object; its `sqrt_α_cumprods` and
  `sqrt_one_minus_α_cumprods` fields are indexed with `timesteps`.
- `clean_data::AbstractArray`: clean data to add noise to.
- `noise::AbstractArray`: noise to add to clean data.
- `timesteps::AbstractArray`: timesteps used to weight the noise. Either a single
  timestep applied to everything, or one timestep per slice along the last
  dimension of `clean_data`.

# Returns
- `AbstractArray`: noisy data at the given timesteps.
"""
function add_noise(
    scheduler::Scheduler,
    clean_data::AbstractArray,
    noise::AbstractArray,
    timesteps::AbstractArray,
)
    # Move the per-timestep coefficients onto the last axis so that a vector of
    # timesteps lines up with the last dimension of the data rather than the first.
    # NOTE(review): assumes samples are stored along the last dimension — confirm.
    leading = ntuple(_ -> 1, max(ndims(clean_data) - 1, 0))
    sqrt_α_cumprod_t = reshape(scheduler.sqrt_α_cumprods[timesteps], leading..., :)
    sqrt_one_minus_α_cumprod_t = reshape(
        scheduler.sqrt_one_minus_α_cumprods[timesteps], leading..., :
    )

    return sqrt_α_cumprod_t .* clean_data .+ sqrt_one_minus_α_cumprod_t .* noise
end
|
||||
|
||||
function step(
|
||||
scheduler::Scheduler,
|
||||
sample::AbstractArray,
|
||||
model_output::AbstractArray,
|
||||
timestep::Int,
|
||||
)
|
||||
"""
|
||||
Remove noise from model output using the backward diffusion process.
|
||||
|
||||
Args:
|
||||
scheduler (`Scheduler`): scheduler object.
|
||||
sample (`AbstractArray`): sample to remove noise from, i.e. model_input.
|
||||
model_output (`AbstractArray`): predicted noise from the model.
|
||||
timestep (`Int`): timestep to remove noise from.
|
||||
|
||||
Returns:
|
||||
`AbstractArray`: denoised model output at the given timestep.
|
||||
"""
|
||||
|
||||
# 1. compute alphas, betas
|
||||
α_cumprod_t = scheduler.α_cumprods[timestep]
|
||||
α_cumprod_t_prev = scheduler.α_cumprods[timestep - 1]
|
||||
β_cumprod_t = 1 - α_cumprod_t
|
||||
β_cumprod_t_prev = 1 - α_cumprod_t_prev
|
||||
current_α_t = α_cumprod_t / α_cumprod_t_prev
|
||||
current_β_t = 1 - current_α_t
|
||||
|
||||
# 2. compute predicted original sample from predicted noise also called
|
||||
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
|
||||
# epsilon prediction type
|
||||
x_0 = (noise - √β_cumprod_t_prev * model_output) / √α_cumprod_t_prev
|
||||
|
||||
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
|
||||
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
||||
pred_original_sample_coeff = (√α_cumprod_t_prev * current_β_t) / β_cumprod_t
|
||||
current_sample_coeff = √current_α_t * β_cumprod_t_prev / β_cumprod_t
|
||||
|
||||
# 5. Compute predicted previous sample µ_t
|
||||
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
||||
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
|
||||
|
||||
# 6. Add noise
|
||||
variance = √scheduler.βs[timestep] * randn(size(model_output))
|
||||
pred_prev_sample = pred_prev_sample + variance
|
||||
|
||||
return pred_prev_sample, pred_original_sample
|
|
@ -1,14 +0,0 @@
|
|||
abstract type Scheduler end
|
||||
|
||||
function add_noise(
|
||||
scheduler::Scheduler,
|
||||
original_samples::AbstractArray,
|
||||
noise::AbstractArray,
|
||||
timesteps::AbstractArray,
|
||||
)
|
||||
alphas_cumprod = scheduler.α_cumprods[timesteps]
|
||||
sqrt_alpha_prod = sqrt.(alphas_cumprod)
|
||||
sqrt_one_minus_alpha_prod = sqrt.(1 .- alphas_cumprod)
|
||||
|
||||
sqrt_alpha_prod .* original_samples .+ sqrt_one_minus_alpha_prod .* noise
|
||||
end
|
Loading…
Reference in a new issue