mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-11-09 15:02:02 +00:00
🚚 move ConditionalChain.jl and Embeddings.jl files out of Diffusers.jl modules
This commit is contained in:
parent
a564a8f6f6
commit
16a1424151
|
@ -4,6 +4,10 @@ using Random
|
||||||
using Plots
|
using Plots
|
||||||
using ProgressMeter
|
using ProgressMeter
|
||||||
|
|
||||||
|
# utils
|
||||||
|
include("Embeddings.jl")
|
||||||
|
include("ConditionalChain.jl")
|
||||||
|
|
||||||
function make_spiral(rng::AbstractRNG, n_samples::Int=1000)
|
function make_spiral(rng::AbstractRNG, n_samples::Int=1000)
|
||||||
t_min = 1.5π
|
t_min = 1.5π
|
||||||
t_max = 4.5π
|
t_max = 4.5π
|
||||||
|
@ -70,12 +74,12 @@ end
|
||||||
gif(anim, "swissroll.gif", fps=50)
|
gif(anim, "swissroll.gif", fps=50)
|
||||||
|
|
||||||
d_hid = 32
|
d_hid = 32
|
||||||
model = Diffusers.ConditionalChain(
|
model = ConditionalChain(
|
||||||
Parallel(
|
Parallel(
|
||||||
.+,
|
.+,
|
||||||
Dense(2, d_hid),
|
Dense(2, d_hid),
|
||||||
Chain(
|
Chain(
|
||||||
Diffusers.SinusoidalPositionEmbedding(num_timesteps, d_hid),
|
SinusoidalPositionEmbedding(num_timesteps, d_hid),
|
||||||
Dense(d_hid, d_hid))
|
Dense(d_hid, d_hid))
|
||||||
),
|
),
|
||||||
relu,
|
relu,
|
||||||
|
@ -83,7 +87,7 @@ model = Diffusers.ConditionalChain(
|
||||||
.+,
|
.+,
|
||||||
Dense(d_hid, d_hid),
|
Dense(d_hid, d_hid),
|
||||||
Chain(
|
Chain(
|
||||||
Diffusers.SinusoidalPositionEmbedding(num_timesteps, d_hid),
|
SinusoidalPositionEmbedding(num_timesteps, d_hid),
|
||||||
Dense(d_hid, d_hid))
|
Dense(d_hid, d_hid))
|
||||||
),
|
),
|
||||||
relu,
|
relu,
|
||||||
|
@ -91,7 +95,7 @@ model = Diffusers.ConditionalChain(
|
||||||
.+,
|
.+,
|
||||||
Dense(d_hid, d_hid),
|
Dense(d_hid, d_hid),
|
||||||
Chain(
|
Chain(
|
||||||
Diffusers.SinusoidalPositionEmbedding(num_timesteps, d_hid),
|
SinusoidalPositionEmbedding(num_timesteps, d_hid),
|
||||||
Dense(d_hid, d_hid))
|
Dense(d_hid, d_hid))
|
||||||
),
|
),
|
||||||
relu,
|
relu,
|
||||||
|
|
|
@ -1,9 +1,5 @@
|
||||||
module Diffusers
|
module Diffusers
|
||||||
|
|
||||||
# utils
|
|
||||||
include("Embeddings.jl")
|
|
||||||
include("ConditionalChain.jl")
|
|
||||||
|
|
||||||
# abtract types
|
# abtract types
|
||||||
include("Schedulers.jl")
|
include("Schedulers.jl")
|
||||||
include("BetaSchedulers.jl")
|
include("BetaSchedulers.jl")
|
||||||
|
|
Loading…
Reference in a new issue