🚚 move ConditionalChain.jl and Embeddings.jl files out of Diffusers.jl modules

This commit is contained in:
Laureηt 2023-07-24 21:05:47 +02:00
parent a564a8f6f6
commit 16a1424151
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
4 changed files with 8 additions and 8 deletions

View file

@ -4,6 +4,10 @@ using Random
using Plots
using ProgressMeter
# utils
include("Embeddings.jl")
include("ConditionalChain.jl")
function make_spiral(rng::AbstractRNG, n_samples::Int=1000)
t_min = 1.5π
t_max = 4.5π
@ -70,12 +74,12 @@ end
gif(anim, "swissroll.gif", fps=50)
d_hid = 32
model = Diffusers.ConditionalChain(
model = ConditionalChain(
Parallel(
.+,
Dense(2, d_hid),
Chain(
Diffusers.SinusoidalPositionEmbedding(num_timesteps, d_hid),
SinusoidalPositionEmbedding(num_timesteps, d_hid),
Dense(d_hid, d_hid))
),
relu,
@ -83,7 +87,7 @@ model = Diffusers.ConditionalChain(
.+,
Dense(d_hid, d_hid),
Chain(
Diffusers.SinusoidalPositionEmbedding(num_timesteps, d_hid),
SinusoidalPositionEmbedding(num_timesteps, d_hid),
Dense(d_hid, d_hid))
),
relu,
@ -91,7 +95,7 @@ model = Diffusers.ConditionalChain(
.+,
Dense(d_hid, d_hid),
Chain(
Diffusers.SinusoidalPositionEmbedding(num_timesteps, d_hid),
SinusoidalPositionEmbedding(num_timesteps, d_hid),
Dense(d_hid, d_hid))
),
relu,

View file

@ -1,9 +1,5 @@
module Diffusers
# utils
include("Embeddings.jl")
include("ConditionalChain.jl")
# abtract types
include("Schedulers.jl")
include("BetaSchedulers.jl")