From 16a1424151d81c6ec3d26b69e9dc1ed36ff0dbe6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laure=CE=B7t?= Date: Mon, 24 Jul 2023 21:05:47 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=9A=20move=20ConditionalChain.jl=20and?= =?UTF-8?q?=20Embeddings.jl=20files=20out=20of=20Diffusers.jl=20modules?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- {src => examples}/ConditionalChain.jl | 0 {src => examples}/Embeddings.jl | 0 examples/swissroll.jl | 12 ++++++++---- src/Diffusers.jl | 4 ---- 4 files changed, 8 insertions(+), 8 deletions(-) rename {src => examples}/ConditionalChain.jl (100%) rename {src => examples}/Embeddings.jl (100%) diff --git a/src/ConditionalChain.jl b/examples/ConditionalChain.jl similarity index 100% rename from src/ConditionalChain.jl rename to examples/ConditionalChain.jl diff --git a/src/Embeddings.jl b/examples/Embeddings.jl similarity index 100% rename from src/Embeddings.jl rename to examples/Embeddings.jl diff --git a/examples/swissroll.jl b/examples/swissroll.jl index ea29357..2dfb83a 100644 --- a/examples/swissroll.jl +++ b/examples/swissroll.jl @@ -4,6 +4,10 @@ using Random using Plots using ProgressMeter +# utils +include("Embeddings.jl") +include("ConditionalChain.jl") + function make_spiral(rng::AbstractRNG, n_samples::Int=1000) t_min = 1.5π t_max = 4.5π @@ -70,12 +74,12 @@ end gif(anim, "swissroll.gif", fps=50) d_hid = 32 -model = Diffusers.ConditionalChain( +model = ConditionalChain( Parallel( .+, Dense(2, d_hid), Chain( - Diffusers.SinusoidalPositionEmbedding(num_timesteps, d_hid), + SinusoidalPositionEmbedding(num_timesteps, d_hid), Dense(d_hid, d_hid)) ), relu, @@ -83,7 +87,7 @@ model = Diffusers.ConditionalChain( .+, Dense(d_hid, d_hid), Chain( - Diffusers.SinusoidalPositionEmbedding(num_timesteps, d_hid), + SinusoidalPositionEmbedding(num_timesteps, d_hid), Dense(d_hid, d_hid)) ), relu, @@ -91,7 +95,7 @@ model = Diffusers.ConditionalChain( .+, Dense(d_hid, d_hid), Chain( - Diffusers.SinusoidalPositionEmbedding(num_timesteps, d_hid), + SinusoidalPositionEmbedding(num_timesteps, d_hid), Dense(d_hid, d_hid)) ), relu, diff --git a/src/Diffusers.jl b/src/Diffusers.jl index 6f7aa98..c15a5f1 100644 --- a/src/Diffusers.jl +++ b/src/Diffusers.jl @@ -1,9 +1,5 @@ module Diffusers -# utils -include("Embeddings.jl") -include("ConditionalChain.jl") - # abtract types include("Schedulers.jl") include("BetaSchedulers.jl")