mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-09-18 18:45:28 +00:00
♻️ (examples) rework swissroll
➕ add DenoisingDiffusion.jl
This commit is contained in:
parent
9ac0b53005
commit
9d5201d068
|
@ -4,7 +4,9 @@ authors = ["Laurent Fainsin <laurent@fainsin.bzh>"]
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
|
||||||
[deps]
|
[deps]
|
||||||
|
DenoisingDiffusion = "32e9e46b-ad0f-4c80-b32d-4f6f824844ef"
|
||||||
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
|
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
|
||||||
|
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
|
||||||
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
|
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
|
||||||
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
|
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
|
||||||
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
|
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
|
||||||
|
|
|
@ -1,99 +0,0 @@
|
||||||
using Flux
|
|
||||||
import Flux._show_children
|
|
||||||
import Flux._big_show
|
|
||||||
|
|
||||||
|
|
||||||
abstract type AbstractParallel end
|
|
||||||
|
|
||||||
_maybe_forward(layer::AbstractParallel, x::AbstractArray, ys::AbstractArray...) = layer(x, ys...)
|
|
||||||
_maybe_forward(layer::Parallel, x::AbstractArray, ys::AbstractArray...) = layer(x, ys...)
|
|
||||||
_maybe_forward(layer, x::AbstractArray, ys::AbstractArray...) = layer(x)
|
|
||||||
|
|
||||||
"""
|
|
||||||
ConditionalChain(layers...)
|
|
||||||
|
|
||||||
Based off `Flux.Chain` except takes in multiple inputs.
|
|
||||||
If a layer is of type `AbstractParallel` it uses all inputs else it uses only the first one.
|
|
||||||
The first input can therefore be conditioned on the other inputs.
|
|
||||||
"""
|
|
||||||
struct ConditionalChain{T<:Union{Tuple,NamedTuple}} <: AbstractParallel
|
|
||||||
layers::T
|
|
||||||
end
|
|
||||||
Flux.@functor ConditionalChain
|
|
||||||
|
|
||||||
ConditionalChain(xs...) = ConditionalChain(xs)
|
|
||||||
function ConditionalChain(; kw...)
|
|
||||||
:layers in keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
|
|
||||||
isempty(kw) && return ConditionalChain(())
|
|
||||||
ConditionalChain(values(kw))
|
|
||||||
end
|
|
||||||
|
|
||||||
Flux.@forward ConditionalChain.layers Base.getindex, Base.length, Base.first, Base.last,
|
|
||||||
Base.iterate, Base.lastindex, Base.keys, Base.firstindex
|
|
||||||
|
|
||||||
Base.getindex(c::ConditionalChain, i::AbstractArray) = ConditionalChain(c.layers[i]...)
|
|
||||||
|
|
||||||
function (c::ConditionalChain)(x, ys...)
|
|
||||||
for layer in c.layers
|
|
||||||
x = _maybe_forward(layer, x, ys...)
|
|
||||||
end
|
|
||||||
x
|
|
||||||
end
|
|
||||||
|
|
||||||
function Base.show(io::IO, c::ConditionalChain)
|
|
||||||
print(io, "ConditionalChain(")
|
|
||||||
Flux._show_layers(io, c.layers)
|
|
||||||
print(io, ")")
|
|
||||||
end
|
|
||||||
|
|
||||||
function _big_show(io::IO, m::ConditionalChain{T}, indent::Integer=0, name=nothing) where {T<:NamedTuple}
|
|
||||||
println(io, " "^indent, isnothing(name) ? "" : "$name = ", "ConditionalChain(")
|
|
||||||
for k in Base.keys(m.layers)
|
|
||||||
_big_show(io, m.layers[k], indent + 2, k)
|
|
||||||
end
|
|
||||||
if indent == 0
|
|
||||||
print(io, ") ")
|
|
||||||
_big_finale(io, m)
|
|
||||||
else
|
|
||||||
println(io, " "^indent, ")", ",")
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
"""
|
|
||||||
ConditionalSkipConnection(layers, connection)
|
|
||||||
|
|
||||||
The output is equivalent to `connection(layers(x, ys...), x)`.
|
|
||||||
Based off Flux.SkipConnection except it passes multiple arguments to layers.
|
|
||||||
"""
|
|
||||||
struct ConditionalSkipConnection{T,F} <: AbstractParallel
|
|
||||||
layers::T
|
|
||||||
connection::F
|
|
||||||
end
|
|
||||||
|
|
||||||
Flux.@functor ConditionalSkipConnection
|
|
||||||
|
|
||||||
function (skip::ConditionalSkipConnection)(x, ys...)
|
|
||||||
skip.connection(skip.layers(x, ys...), x)
|
|
||||||
end
|
|
||||||
|
|
||||||
function Base.show(io::IO, b::ConditionalSkipConnection)
|
|
||||||
print(io, "ConditionalSkipConnection(", b.layers, ", ", b.connection, ")")
|
|
||||||
end
|
|
||||||
|
|
||||||
### Show. Copied from Flux.jl/src/layers/show.jl
|
|
||||||
|
|
||||||
for T in [
|
|
||||||
:ConditionalChain, ConditionalSkipConnection
|
|
||||||
]
|
|
||||||
@eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
|
|
||||||
if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL
|
|
||||||
Flux._big_show(io, x)
|
|
||||||
elseif !get(io, :compact, false) # e.g. printed inside a Vector, but not a Matrix
|
|
||||||
Flux._layer_show(io, x)
|
|
||||||
else
|
|
||||||
show(io, x)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
_show_children(c::ConditionalChain) = c.layers
|
|
|
@ -1,33 +0,0 @@
|
||||||
using Flux
|
|
||||||
|
|
||||||
struct SinusoidalPositionEmbedding{W<:AbstractArray}
|
|
||||||
weight::W
|
|
||||||
end
|
|
||||||
|
|
||||||
Flux.@functor SinusoidalPositionEmbedding
|
|
||||||
Flux.trainable(emb::SinusoidalPositionEmbedding) = () # mark it as a non-trainable array
|
|
||||||
|
|
||||||
function SinusoidalPositionEmbedding(in::Integer, out::Integer)
|
|
||||||
W = make_positional_embedding(out, in)
|
|
||||||
SinusoidalPositionEmbedding(W)
|
|
||||||
end
|
|
||||||
|
|
||||||
function make_positional_embedding(dim_embedding::Integer, seq_length::Integer=1000; n::Integer=10000)
|
|
||||||
embedding = Matrix{Float32}(undef, dim_embedding, seq_length)
|
|
||||||
for pos in 1:seq_length
|
|
||||||
for row in 0:2:(dim_embedding-1)
|
|
||||||
denom = 1.0 / (n^(row / (dim_embedding - 2)))
|
|
||||||
embedding[row+1, pos] = sin(pos * denom)
|
|
||||||
embedding[row+2, pos] = cos(pos * denom)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
embedding
|
|
||||||
end
|
|
||||||
|
|
||||||
(m::SinusoidalPositionEmbedding)(x::Integer) = m.weight[:, x]
|
|
||||||
(m::SinusoidalPositionEmbedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
|
|
||||||
(m::SinusoidalPositionEmbedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
|
|
||||||
|
|
||||||
function Base.show(io::IO, m::SinusoidalPositionEmbedding)
|
|
||||||
print(io, "SinusoidalPositionEmbedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
|
|
||||||
end
|
|
|
@ -6,16 +6,11 @@ using Flux
|
||||||
using Random
|
using Random
|
||||||
using Plots
|
using Plots
|
||||||
using ProgressMeter
|
using ProgressMeter
|
||||||
|
using DenoisingDiffusion
|
||||||
|
using LaTeXStrings
|
||||||
|
|
||||||
# utils
|
function make_spiral(n_samples::Integer=1000, t_min::Real=1.5π, t_max::Real=4.5π)
|
||||||
include("Embeddings.jl")
|
t = rand(n_samples) * (t_max - t_min) .+ t_min
|
||||||
include("ConditionalChain.jl")
|
|
||||||
|
|
||||||
function make_spiral(rng::AbstractRNG, n_samples::Integer=1000)
|
|
||||||
t_min = 1.5π
|
|
||||||
t_max = 4.5π
|
|
||||||
|
|
||||||
t = rand(rng, n_samples) * (t_max - t_min) .+ t_min
|
|
||||||
|
|
||||||
x = t .* cos.(t)
|
x = t .* cos.(t)
|
||||||
y = t .* sin.(t)
|
y = t .* sin.(t)
|
||||||
|
@ -23,8 +18,6 @@ function make_spiral(rng::AbstractRNG, n_samples::Integer=1000)
|
||||||
permutedims([x y], (2, 1))
|
permutedims([x y], (2, 1))
|
||||||
end
|
end
|
||||||
|
|
||||||
make_spiral(n_samples::Integer=1000) = make_spiral(Random.GLOBAL_RNG, n_samples)
|
|
||||||
|
|
||||||
function normalize_zero_to_one(x)
|
function normalize_zero_to_one(x)
|
||||||
x_min, x_max = extrema(x)
|
x_min, x_max = extrema(x)
|
||||||
x_norm = (x .- x_min) ./ (x_max - x_min)
|
x_norm = (x .- x_min) ./ (x_max - x_min)
|
||||||
|
@ -35,9 +28,11 @@ function normalize_neg_one_to_one(x)
|
||||||
2 * normalize_zero_to_one(x) .- 1
|
2 * normalize_zero_to_one(x) .- 1
|
||||||
end
|
end
|
||||||
|
|
||||||
n_samples = 1000
|
# make a dataset of n_points points sampled along a single spiral
|
||||||
data = normalize_neg_one_to_one(make_spiral(n_samples))
|
n_points = 2500
|
||||||
scatter(data[1, :], data[2, :],
|
dataset = make_spiral(n_points, 1π, 5π)
|
||||||
|
dataset = normalize_neg_one_to_one(dataset)
|
||||||
|
scatter(dataset[1, :], dataset[2, :],
|
||||||
alpha=0.5,
|
alpha=0.5,
|
||||||
aspectratio=:equal,
|
aspectratio=:equal,
|
||||||
)
|
)
|
||||||
|
@ -46,37 +41,56 @@ num_timesteps = 100
|
||||||
scheduler = DDPM(
|
scheduler = DDPM(
|
||||||
Vector{Float64},
|
Vector{Float64},
|
||||||
rescale_zero_terminal_snr(
|
rescale_zero_terminal_snr(
|
||||||
cosine_beta_schedule(num_timesteps, 0.999f0, 0.001f0)
|
cosine_beta_schedule(num_timesteps)
|
||||||
)
|
)
|
||||||
)
|
);
|
||||||
|
|
||||||
|
data = dataset[:, 1:100]
|
||||||
noise = randn(size(data))
|
noise = randn(size(data))
|
||||||
|
|
||||||
anim = @animate for i in cat(1:num_timesteps, repeat([num_timesteps], 50), dims=1)
|
anim = @animate for t in cat(fill(0, 25), 1:num_timesteps, fill(num_timesteps, 50), dims=1)
|
||||||
noisy_data = Diffusers.Schedulers.add_noise(scheduler, data, noise, [i])
|
if t == 0
|
||||||
scatter(noise[1, :], noise[2, :],
|
scatter(noise[1, :], noise[2, :],
|
||||||
alpha=0.1,
|
alpha=0.3,
|
||||||
aspectratio=:equal,
|
aspectratio=:equal,
|
||||||
label="noise",
|
label="noise",
|
||||||
legend=:outertopright,
|
legend=:outertopright,
|
||||||
)
|
)
|
||||||
scatter!(noisy_data[1, :], noisy_data[2, :],
|
scatter!(data[1, :], data[2, :],
|
||||||
alpha=0.5,
|
alpha=0.3,
|
||||||
aspectratio=:equal,
|
aspectratio=:equal,
|
||||||
label="noisy data",
|
label="data",
|
||||||
)
|
)
|
||||||
scatter!(data[1, :], data[2, :],
|
scatter!(data[1, :], data[2, :],
|
||||||
alpha=0.5,
|
aspectratio=:equal,
|
||||||
aspectratio=:equal,
|
label="noisy data",
|
||||||
label="data",
|
)
|
||||||
)
|
title!("t = " * lpad(t, 3, "0"))
|
||||||
i_str = lpad(i, 3, "0")
|
xlims!(-3, 3)
|
||||||
title!("t = $(i_str)")
|
ylims!(-3, 3)
|
||||||
xlims!(-3, 3)
|
else
|
||||||
ylims!(-3, 3)
|
noisy_data = Diffusers.Schedulers.add_noise(scheduler, data, noise, [t])
|
||||||
|
scatter(noise[1, :], noise[2, :],
|
||||||
|
alpha=0.3,
|
||||||
|
aspectratio=:equal,
|
||||||
|
label="noise",
|
||||||
|
legend=:outertopright,
|
||||||
|
)
|
||||||
|
scatter!(data[1, :], data[2, :],
|
||||||
|
alpha=0.3,
|
||||||
|
aspectratio=:equal,
|
||||||
|
label="data",
|
||||||
|
)
|
||||||
|
scatter!(noisy_data[1, :], noisy_data[2, :],
|
||||||
|
aspectratio=:equal,
|
||||||
|
label="noisy data",
|
||||||
|
)
|
||||||
|
title!(latexstring("t = " * lpad(t, 3, "0")))
|
||||||
|
xlims!(-3, 3)
|
||||||
|
ylims!(-3, 3)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
gif(anim, anim.dir * ".gif", fps=50)
|
||||||
gif(anim, "swissroll.gif", fps=50)
|
|
||||||
|
|
||||||
d_hid = 32
|
d_hid = 32
|
||||||
model = ConditionalChain(
|
model = ConditionalChain(
|
||||||
|
@ -109,16 +123,16 @@ model = ConditionalChain(
|
||||||
|
|
||||||
model(data, [100])
|
model(data, [100])
|
||||||
|
|
||||||
num_epochs = 100
|
num_epochs = 100;
|
||||||
loss = Flux.Losses.mse
|
loss = Flux.Losses.mse;
|
||||||
opt = Flux.setup(Adam(0.0001), model)
|
opt = Flux.setup(Adam(0.0001), model);
|
||||||
dataloader = Flux.DataLoader(data |> cpu; batchsize=32, shuffle=true);
|
dataloader = Flux.DataLoader(dataset |> cpu; batchsize=32, shuffle=true);
|
||||||
progress = Progress(num_epochs; desc="training", showspeed=true)
|
progress = Progress(num_epochs; desc="training", showspeed=true);
|
||||||
for epoch = 1:num_epochs
|
for epoch = 1:num_epochs
|
||||||
params = Flux.params(model)
|
params = Flux.params(model)
|
||||||
for data in dataloader
|
for data in dataloader
|
||||||
noise = randn(size(data))
|
noise = randn(size(data))
|
||||||
timesteps = rand(2:num_timesteps, size(data)[2]) # TODO: fix start at timestep=2, bruh
|
timesteps = rand(2:num_timesteps, size(data, ndims(data))) # TODO: fix start at timestep=2, bruh
|
||||||
noisy_data = Diffusers.Schedulers.add_noise(scheduler, data, noise, timesteps)
|
noisy_data = Diffusers.Schedulers.add_noise(scheduler, data, noise, timesteps)
|
||||||
grads = Flux.gradient(model) do m
|
grads = Flux.gradient(model) do m
|
||||||
model_output = m(noisy_data, timesteps)
|
model_output = m(noisy_data, timesteps)
|
||||||
|
@ -130,43 +144,68 @@ for epoch = 1:num_epochs
|
||||||
ProgressMeter.next!(progress)
|
ProgressMeter.next!(progress)
|
||||||
end
|
end
|
||||||
|
|
||||||
# sampling animation
|
## sampling animation
|
||||||
anim = @animate for timestep in num_timesteps:-1:2
|
|
||||||
|
sample = randn(2, 100)
|
||||||
|
sample_old = sample
|
||||||
|
predictions = []
|
||||||
|
anim = for timestep in num_timesteps:-1:1
|
||||||
model_output = model(data, [timestep])
|
model_output = model(data, [timestep])
|
||||||
sampled_data, x0_pred = Diffusers.Schedulers.step(scheduler, data, model_output, [timestep])
|
sample, x0_pred = Diffusers.Schedulers.step(scheduler, data, model_output, [timestep])
|
||||||
|
push!(predictions, (sample, x0_pred, timestep))
|
||||||
p1 = scatter(sampled_data[1, :], sampled_data[2, :],
|
|
||||||
alpha=0.5,
|
|
||||||
aspectratio=:equal,
|
|
||||||
label="sampled data",
|
|
||||||
legend=false,
|
|
||||||
)
|
|
||||||
scatter!(data[1, :], data[2, :],
|
|
||||||
alpha=0.5,
|
|
||||||
aspectratio=:equal,
|
|
||||||
label="data",
|
|
||||||
)
|
|
||||||
|
|
||||||
p2 = scatter(x0_pred[1, :], x0_pred[2, :],
|
|
||||||
alpha=0.5,
|
|
||||||
aspectratio=:equal,
|
|
||||||
label="sampled data",
|
|
||||||
legend=false,
|
|
||||||
)
|
|
||||||
scatter!(data[1, :], data[2, :],
|
|
||||||
alpha=0.5,
|
|
||||||
aspectratio=:equal,
|
|
||||||
label="data",
|
|
||||||
)
|
|
||||||
|
|
||||||
l = @layout [a b]
|
|
||||||
i_str = lpad(timestep, 3, "0")
|
|
||||||
plot(p1, p2,
|
|
||||||
layout=l,
|
|
||||||
plot_title="t = $(i_str)",
|
|
||||||
)
|
|
||||||
xlims!(-2, 2)
|
|
||||||
ylims!(-2, 2)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
gif(anim, "sampling.gif", fps=30)
|
anim = @animate for i in cat(fill(0, 50), 1:num_timesteps, fill(num_timesteps, 50), dims=1)
|
||||||
|
if i == 0
|
||||||
|
p1 = scatter(dataset[1, :], dataset[2, :],
|
||||||
|
alpha=0.01,
|
||||||
|
aspectratio=:equal,
|
||||||
|
title=L"x_t",
|
||||||
|
legend=false,
|
||||||
|
)
|
||||||
|
scatter!(sample_old[1, :], sample_old[2, :])
|
||||||
|
|
||||||
|
p2 = scatter(dataset[1, :], dataset[2, :],
|
||||||
|
alpha=0.01,
|
||||||
|
aspectratio=:equal,
|
||||||
|
title=L"x_0",
|
||||||
|
legend=false,
|
||||||
|
)
|
||||||
|
|
||||||
|
l = @layout [a b]
|
||||||
|
t_str = lpad(num_timesteps, 3, "0")
|
||||||
|
plot(p1, p2,
|
||||||
|
layout=l,
|
||||||
|
plot_title=latexstring("t = $(t_str)"),
|
||||||
|
)
|
||||||
|
xlims!(-2, 2)
|
||||||
|
ylims!(-2, 2)
|
||||||
|
else
|
||||||
|
sample, x_0, timestep = predictions[i]
|
||||||
|
p1 = scatter(dataset[1, :], dataset[2, :],
|
||||||
|
alpha=0.01,
|
||||||
|
aspectratio=:equal,
|
||||||
|
legend=false,
|
||||||
|
title=L"x_t",
|
||||||
|
)
|
||||||
|
scatter!(sample[1, :], sample[2, :])
|
||||||
|
|
||||||
|
p2 = scatter(dataset[1, :], dataset[2, :],
|
||||||
|
alpha=0.01,
|
||||||
|
aspectratio=:equal,
|
||||||
|
legend=false,
|
||||||
|
title=L"x_0",
|
||||||
|
)
|
||||||
|
scatter!(x_0[1, :], x_0[2, :])
|
||||||
|
|
||||||
|
l = @layout [a b]
|
||||||
|
t_str = lpad(timestep - 1, 3, "0")
|
||||||
|
plot(p1, p2,
|
||||||
|
layout=l,
|
||||||
|
plot_title=latexstring("t = $(t_str)"),
|
||||||
|
)
|
||||||
|
xlims!(-2, 2)
|
||||||
|
ylims!(-2, 2)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
gif(anim, anim.dir * ".gif", fps=50)
|
||||||
|
|
|
@ -87,12 +87,17 @@ function add_noise(
|
||||||
ϵ::AbstractArray,
|
ϵ::AbstractArray,
|
||||||
t::AbstractArray,
|
t::AbstractArray,
|
||||||
)
|
)
|
||||||
⎷α̅ₜ = scheduler.⎷α̅[t]
|
# retrieve scheduler variables at timesteps t
|
||||||
⎷β̅ₜ = scheduler.⎷β̅[t]
|
reshape_size = tuple(
|
||||||
|
fill(1, ndims(x₀) - 1)...,
|
||||||
|
size(t, 1)
|
||||||
|
)
|
||||||
|
⎷α̅ₜ = reshape(scheduler.⎷α̅[t], reshape_size)
|
||||||
|
⎷β̅ₜ = reshape(scheduler.⎷β̅[t], reshape_size)
|
||||||
|
|
||||||
# noisify clean data
|
# noisify clean data
|
||||||
# arxiv:2006.11239 Eq. 4
|
# arxiv:2006.11239 Eq. 4
|
||||||
xₜ = ⎷α̅ₜ' .* x₀ + ⎷β̅ₜ' .* ϵ
|
xₜ = ⎷α̅ₜ .* x₀ + ⎷β̅ₜ .* ϵ
|
||||||
|
|
||||||
return xₜ
|
return xₜ
|
||||||
end
|
end
|
||||||
|
@ -117,31 +122,35 @@ function step(
|
||||||
t::AbstractArray,
|
t::AbstractArray,
|
||||||
)
|
)
|
||||||
# retrieve scheduler variables at timesteps t
|
# retrieve scheduler variables at timesteps t
|
||||||
βₜ = scheduler.β[t]
|
reshape_size = tuple(
|
||||||
β̅ₜ = scheduler.β̅[t]
|
fill(1, ndims(xₜ) - 1)...,
|
||||||
β̅ₜ₋₁ = scheduler.β̅₋₁[t]
|
size(t, 1)
|
||||||
⎷αₜ = scheduler.⎷α[t]
|
)
|
||||||
⎷α̅ₜ = scheduler.⎷α̅[t]
|
βₜ = reshape(scheduler.β[t], reshape_size)
|
||||||
⎷α̅ₜ₋₁ = scheduler.⎷α̅₋₁[t]
|
β̅ₜ = reshape(scheduler.β̅[t], reshape_size)
|
||||||
⎷β̅ₜ = scheduler.⎷β̅[t]
|
β̅ₜ₋₁ = reshape(scheduler.β̅₋₁[t], reshape_size)
|
||||||
|
⎷αₜ = reshape(scheduler.⎷α[t], reshape_size)
|
||||||
|
⎷α̅ₜ = reshape(scheduler.⎷α̅[t], reshape_size)
|
||||||
|
⎷α̅ₜ₋₁ = reshape(scheduler.⎷α̅₋₁[t], reshape_size)
|
||||||
|
⎷β̅ₜ = reshape(scheduler.⎷β̅[t], reshape_size)
|
||||||
|
|
||||||
# compute predicted original sample x̂₀
|
# compute predicted original sample x̂₀
|
||||||
# arxiv:2006.11239 Eq. 15
|
# arxiv:2006.11239 Eq. 15
|
||||||
# arxiv:2208.11970 Eq. 115
|
# arxiv:2208.11970 Eq. 115
|
||||||
x̂₀ = (xₜ - ⎷β̅ₜ' .* ϵᵧ) ./ ⎷α̅ₜ'
|
x̂₀ = (xₜ - ⎷β̅ₜ .* ϵᵧ) ./ ⎷α̅ₜ
|
||||||
|
|
||||||
# compute predicted previous sample μ̃ₜ
|
# compute predicted previous sample μ̃ₜ
|
||||||
# arxiv:2006.11239 Eq. 7
|
# arxiv:2006.11239 Eq. 7
|
||||||
# arxiv:2208.11970 Eq. 84
|
# arxiv:2208.11970 Eq. 84
|
||||||
λ₀ = ⎷α̅ₜ₋₁ .* βₜ ./ β̅ₜ
|
λ₀ = ⎷α̅ₜ₋₁ .* βₜ ./ β̅ₜ
|
||||||
λₜ = ⎷αₜ .* β̅ₜ₋₁ ./ β̅ₜ # TODO: this could be stored in the scheduler
|
λₜ = ⎷αₜ .* β̅ₜ₋₁ ./ β̅ₜ # TODO: this could be stored in the scheduler
|
||||||
μ̃ₜ = λ₀' .* x̂₀ + λₜ' .* xₜ
|
μ̃ₜ = λ₀ .* x̂₀ + λₜ .* xₜ
|
||||||
|
|
||||||
# sample predicted previous sample xₜ₋₁
|
# sample predicted previous sample xₜ₋₁
|
||||||
# arxiv:2006.11239 Eq. 6
|
# arxiv:2006.11239 Eq. 6
|
||||||
# arxiv:2208.11970 Eq. 70
|
# arxiv:2208.11970 Eq. 70
|
||||||
σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ # TODO: this could be stored in the scheduler
|
σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ # TODO: this could be stored in the scheduler
|
||||||
xₜ₋₁ = μ̃ₜ + σₜ' .* randn(size(ϵᵧ))
|
xₜ₋₁ = μ̃ₜ + σₜ .* randn(size(ϵᵧ))
|
||||||
|
|
||||||
return xₜ₋₁, x̂₀
|
return xₜ₋₁, x̂₀
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in a new issue