Diffusers.jl/examples/Embeddings.jl

34 lines
1.1 KiB
Julia
Raw Normal View History

using Flux
struct SinusoidalPositionEmbedding{W<:AbstractArray}
weight::W
end
Flux.@functor SinusoidalPositionEmbedding
Flux.trainable(emb::SinusoidalPositionEmbedding) = () # mark it as an non-trainable array
function SinusoidalPositionEmbedding(in::Int, out::Int)
W = make_positional_embedding(out, in)
SinusoidalPositionEmbedding(W)
end
function make_positional_embedding(dim_embedding::Int, seq_length::Int=1000; n::Int=10000)
embedding = Matrix{Float32}(undef, dim_embedding, seq_length)
for pos in 1:seq_length
for row in 0:2:(dim_embedding-1)
denom = 1.0 / (n^(row / (dim_embedding - 2)))
embedding[row+1, pos] = sin(pos * denom)
embedding[row+2, pos] = cos(pos * denom)
end
end
embedding
end
(m::SinusoidalPositionEmbedding)(x::Integer) = m.weight[:, x]
(m::SinusoidalPositionEmbedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
(m::SinusoidalPositionEmbedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
function Base.show(io::IO, m::SinusoidalPositionEmbedding)
print(io, "SinusoidalPositionEmbedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
end