🚚 move ConditionalChain.jl and Embeddings.jl files out of Diffusers.jl modules

This commit is contained in:
Laureηt 2023-07-24 21:05:47 +02:00
parent a564a8f6f6
commit 16a1424151
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
4 changed files with 8 additions and 8 deletions

View file

@ -4,6 +4,10 @@ using Random
using Plots using Plots
using ProgressMeter using ProgressMeter
# utils
include("Embeddings.jl")
include("ConditionalChain.jl")
function make_spiral(rng::AbstractRNG, n_samples::Int=1000) function make_spiral(rng::AbstractRNG, n_samples::Int=1000)
t_min = 1.5π t_min = 1.5π
t_max = 4.5π t_max = 4.5π
@ -70,12 +74,12 @@ end
gif(anim, "swissroll.gif", fps=50) gif(anim, "swissroll.gif", fps=50)
d_hid = 32 d_hid = 32
model = Diffusers.ConditionalChain( model = ConditionalChain(
Parallel( Parallel(
.+, .+,
Dense(2, d_hid), Dense(2, d_hid),
Chain( Chain(
Diffusers.SinusoidalPositionEmbedding(num_timesteps, d_hid), SinusoidalPositionEmbedding(num_timesteps, d_hid),
Dense(d_hid, d_hid)) Dense(d_hid, d_hid))
), ),
relu, relu,
@ -83,7 +87,7 @@ model = Diffusers.ConditionalChain(
.+, .+,
Dense(d_hid, d_hid), Dense(d_hid, d_hid),
Chain( Chain(
Diffusers.SinusoidalPositionEmbedding(num_timesteps, d_hid), SinusoidalPositionEmbedding(num_timesteps, d_hid),
Dense(d_hid, d_hid)) Dense(d_hid, d_hid))
), ),
relu, relu,
@ -91,7 +95,7 @@ model = Diffusers.ConditionalChain(
.+, .+,
Dense(d_hid, d_hid), Dense(d_hid, d_hid),
Chain( Chain(
Diffusers.SinusoidalPositionEmbedding(num_timesteps, d_hid), SinusoidalPositionEmbedding(num_timesteps, d_hid),
Dense(d_hid, d_hid)) Dense(d_hid, d_hid))
), ),
relu, relu,

View file

@ -1,9 +1,5 @@
module Diffusers module Diffusers
# utils
include("Embeddings.jl")
include("ConditionalChain.jl")
# abtract types # abtract types
include("Schedulers.jl") include("Schedulers.jl")
include("BetaSchedulers.jl") include("BetaSchedulers.jl")