mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-11-14 00:58:13 +00:00
34 lines
1.1 KiB
Julia
34 lines
1.1 KiB
Julia
|
using Flux
|
||
|
|
||
|
"""
    SinusoidalPositionEmbedding(weight)

Layer that wraps a single array, `weight`, and is called like a function to
look entries up in it: calling the layer with an integer `x` returns
`weight[:, x]`.

The type parameter `W` records the concrete array type of `weight`, so the
field is concretely typed for any `AbstractArray` that is stored.
"""
struct SinusoidalPositionEmbedding{W<:AbstractArray}
    # Lookup table; the call methods index it by column.
    weight::W
end
|
||
|
|
||
|
# Register the layer type with Flux through its `@functor` macro.
Flux.@functor SinusoidalPositionEmbedding

# Mark the weight as a non-trainable array: the layer reports an empty
# collection of trainable parameters.
function Flux.trainable(emb::SinusoidalPositionEmbedding)
    return ()
end
|
||
|
|
||
|
# Convenience constructor: build the weight table with
# `make_positional_embedding(out, in)` (so `out` is passed as the embedding
# dimension and `in` as the sequence length) and wrap it in the layer.
SinusoidalPositionEmbedding(in::Int, out::Int) =
    SinusoidalPositionEmbedding(make_positional_embedding(out, in))
|
||
|
|
||
|
"""
    make_positional_embedding(dim_embedding::Int, seq_length::Int=1000; n::Int=10000)

Build a `dim_embedding × seq_length` `Float32` matrix of sinusoidal position
encodings.

Column `pos` (1-based) holds the encoding of position `pos`. Rows come in
pairs: for the pair starting at zero-based row `row`, the first row stores
`sin(pos * f)` and the second stores `cos(pos * f)`, where
`f = 1 / n^(row / (dim_embedding - 2))`. The first pair therefore always uses
`f = 1`.

# Arguments
- `dim_embedding`: number of rows of the result; must be even.
- `seq_length`: number of columns (positions) of the result.

# Keywords
- `n`: base of the geometric progression of frequencies.

# Returns
A `Matrix{Float32}` of size `(dim_embedding, seq_length)`.

# Throws
- `ArgumentError` if `dim_embedding` is odd (rows are always written in
  sine/cosine pairs, so an odd row count cannot be filled).
"""
function make_positional_embedding(dim_embedding::Int, seq_length::Int=1000; n::Int=10000)
    iseven(dim_embedding) ||
        throw(ArgumentError("dim_embedding must be even, got $dim_embedding"))

    embedding = Matrix{Float32}(undef, dim_embedding, seq_length)

    # One frequency per sine/cosine pair. They do not depend on the position,
    # so compute them once instead of once per column.
    # The first pair's exponent is pinned to 0: `0 / (dim_embedding - 2)` is
    # `0 / 0 == NaN` when `dim_embedding == 2`, which would poison the matrix.
    freqs = [
        1.0 / (n^(row == 0 ? 0.0 : row / (dim_embedding - 2))) for
        row in 0:2:(dim_embedding - 1)
    ]

    # Positions in the outer loop so each column is filled contiguously
    # (Julia arrays are column-major).
    for pos in 1:seq_length
        for (k, freq) in enumerate(freqs)
            embedding[2k - 1, pos] = sin(pos * freq)
            embedding[2k, pos] = cos(pos * freq)
        end
    end
    return embedding
end
|
||
|
|
||
|
# Single index: return column `x` of the weight array.
function (m::SinusoidalPositionEmbedding)(x::Integer)
    return m.weight[:, x]
end
|
||
|
# Vector of indices: delegate the lookup to `NNlib.gather` on the weight array
# (presumably yielding one weight column per entry of `x` — confirm against
# the NNlib documentation).
function (m::SinusoidalPositionEmbedding)(x::AbstractVector)
    return NNlib.gather(m.weight, x)
end
|
||
|
# Array of indices of any shape: flatten `x`, look the flattened indices up
# with the vector method, then reshape so the trailing dimensions match
# `size(x)` and the leading dimension absorbs the rest.
function (m::SinusoidalPositionEmbedding)(x::AbstractArray)
    flat_lookup = m(vec(x))
    return reshape(flat_lookup, :, size(x)...)
end
|
||
|
|
||
|
# Compact display of the layer: prints
# `SinusoidalPositionEmbedding(<columns> => <rows>)`, where the two numbers are
# the second and first dimension of the weight array, in that order.
function Base.show(io::IO, m::SinusoidalPositionEmbedding)
    ncols = size(m.weight, 2)
    nrows = size(m.weight, 1)
    print(io, "SinusoidalPositionEmbedding(", ncols, " => ", nrows, ")")
end
|