Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-21 21:58:47 +00:00)

Commit 9d663534d1: cosmetic changes for text_encoder.py
Parent: b8e7179447
text_encoder.py

@@ -1,21 +1,9 @@
 from torch import Tensor, arange, device as Device, dtype as DType
-from refiners.fluxion.layers import (
-    ApproximateGeLU,
-    GeLU,
-    Linear,
-    LayerNorm,
-    Embedding,
-    Chain,
-    Sum,
-    SelfAttention,
-    Lambda,
-    Residual,
-)
+import refiners.fluxion.layers as fl
 from refiners.foundationals.clip.tokenizer import CLIPTokenizer
 
 
-class PositionalTokenEncoder(Sum):
+class PositionalTokenEncoder(fl.Sum):
     structural_attrs = ["vocabulary_size", "positional_embedding_dim"]
 
     def __init__(
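The refactor replaces the long per-class import with a single namespaced import, so every layer reads as fl.<Layer> at the call site. A minimal sketch of the resulting style, using only constructors that appear in this diff (toy dimensions are illustrative, and device/dtype are omitted here even though the repository code always passes them explicitly):

import refiners.fluxion.layers as fl

# a toy two-layer block written in the same declarative style as the code below
block = fl.Chain(
    fl.Linear(in_features=16, out_features=64),
    fl.GeLU(),
    fl.Linear(in_features=64, out_features=16),
)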
@@ -29,15 +17,15 @@ class PositionalTokenEncoder(Sum):
         self.vocabulary_size = vocabulary_size
         self.positional_embedding_dim = positional_embedding_dim
         super().__init__(
-            Embedding(
+            fl.Embedding(
                 num_embeddings=vocabulary_size,
                 embedding_dim=embedding_dim,
                 device=device,
                 dtype=dtype,
             ),
-            Chain(
-                Lambda(func=self.get_position_ids),
-                Embedding(
+            fl.Chain(
+                fl.Lambda(func=self.get_position_ids),
+                fl.Embedding(
                     num_embeddings=positional_embedding_dim,
                     embedding_dim=embedding_dim,
                     device=device,
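PositionalTokenEncoder is an fl.Sum of two branches: a token embedding, and a chain that derives position ids from the input via fl.Lambda and looks them up in a second embedding. A plain-PyTorch sketch of the computation this structure implies (sizes are typical CLIP values used only for illustration; this is not repository code):

import torch
from torch import nn

vocabulary_size, positional_embedding_dim, embedding_dim = 49408, 77, 768  # illustrative CLIP-like sizes

token_embedding = nn.Embedding(vocabulary_size, embedding_dim)
position_embedding = nn.Embedding(positional_embedding_dim, embedding_dim)

tokens = torch.randint(0, vocabulary_size, (1, 77))             # (batch, sequence) of token ids
position_ids = torch.arange(tokens.shape[1]).unsqueeze(0)       # positions trimmed to the sequence length
x = token_embedding(tokens) + position_embedding(position_ids)  # the Sum of the two branches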
@@ -54,7 +42,7 @@ class PositionalTokenEncoder(Sum):
         return self.position_ids[:, : x.shape[1]]
 
 
-class FeedForward(Chain):
+class FeedForward(fl.Chain):
     structural_attrs = ["embedding_dim", "feedforward_dim"]
 
     def __init__(
@@ -67,13 +55,13 @@ class FeedForward(Chain):
         self.embedding_dim = embedding_dim
         self.feedforward_dim = feedforward_dim
         super().__init__(
-            Linear(in_features=embedding_dim, out_features=feedforward_dim, device=device, dtype=dtype),
-            GeLU(),
-            Linear(in_features=feedforward_dim, out_features=embedding_dim, device=device, dtype=dtype),
+            fl.Linear(in_features=embedding_dim, out_features=feedforward_dim, device=device, dtype=dtype),
+            fl.GeLU(),
+            fl.Linear(in_features=feedforward_dim, out_features=embedding_dim, device=device, dtype=dtype),
         )
 
 
-class TransformerLayer(Chain):
+class TransformerLayer(fl.Chain):
     structural_attrs = ["embedding_dim", "num_attention_heads", "feedforward_dim", "layer_norm_eps"]
 
     def __init__(
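FeedForward is the standard transformer MLP: project to feedforward_dim, apply GeLU, project back to embedding_dim. An equivalent plain-PyTorch sketch (illustrative sizes; not repository code):

import torch
from torch import nn

embedding_dim, feedforward_dim = 768, 3072  # illustrative sizes
mlp = nn.Sequential(
    nn.Linear(embedding_dim, feedforward_dim),
    nn.GELU(),
    nn.Linear(feedforward_dim, embedding_dim),
)
y = mlp(torch.randn(1, 77, embedding_dim))  # the last dimension is preserved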
@@ -90,14 +78,14 @@ class TransformerLayer(Chain):
         self.feedforward_dim = feedforward_dim
         self.layer_norm_eps = layer_norm_eps
         super().__init__(
-            Residual(
-                LayerNorm(
+            fl.Residual(
+                fl.LayerNorm(
                     normalized_shape=embedding_dim,
                     eps=layer_norm_eps,
                     device=device,
                     dtype=dtype,
                 ),
-                SelfAttention(
+                fl.SelfAttention(
                     embedding_dim=embedding_dim,
                     num_heads=num_attention_heads,
                     is_causal=True,
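Each fl.Residual presumably adds the output of its inner chain back to its input, so this first block is a pre-norm, causal self-attention residual: x + attention(layer_norm(x)). A plain-PyTorch sketch of that pattern (an illustration under that assumption, not repository code):

import torch
from torch import nn

embedding_dim, num_heads, seq_len = 768, 12, 77  # illustrative sizes
norm = nn.LayerNorm(embedding_dim, eps=1e-5)
attention = nn.MultiheadAttention(embedding_dim, num_heads, batch_first=True)

x = torch.randn(1, seq_len, embedding_dim)
h = norm(x)                                                                      # pre-norm
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)  # mask future positions
x = x + attention(h, h, h, attn_mask=causal, need_weights=False)[0]              # residual add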
@@ -105,8 +93,8 @@ class TransformerLayer(Chain):
                     dtype=dtype,
                 ),
             ),
-            Residual(
-                LayerNorm(
+            fl.Residual(
+                fl.LayerNorm(
                     normalized_shape=embedding_dim,
                     eps=layer_norm_eps,
                     device=device,
@@ -122,7 +110,7 @@ class TransformerLayer(Chain):
         )
 
 
-class CLIPTextEncoder(Chain):
+class CLIPTextEncoder(fl.Chain):
     structural_attrs = [
         "embedding_dim",
         "positional_embedding_dim",
@@ -177,11 +165,11 @@ class CLIPTextEncoder(Chain):
                 )
                 for _ in range(num_layers)
             ),
-            LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
+            fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
         )
         if use_quick_gelu:
-            for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, GeLU)):
-                parent.replace(old_module=gelu, new_module=ApproximateGeLU())
+            for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):
+                parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())
 
     def encode(self, text: str) -> Tensor:
         tokens = self.tokenizer(text, sequence_length=self.positional_embedding_dim).to(device=self.device)
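The use_quick_gelu branch walks the chain and swaps every fl.GeLU for fl.ApproximateGeLU, matching the "quick" GeLU used by the original CLIP weights. A short plain-PyTorch sketch of the usual sigmoid approximation that name refers to (an assumption about what ApproximateGeLU computes; not repository code):

import torch

x = torch.linspace(-3, 3, 7)
exact = torch.nn.functional.gelu(x)
quick = x * torch.sigmoid(1.702 * x)    # the common quick-GeLU approximation
print((exact - quick).abs().max())      # small, but not zero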
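Finally, encode tokenizes the input text with CLIPTokenizer and runs it through the chain. A hedged usage sketch (the module path is inferred from the commit title, and the constructor is assumed to provide defaults for the structural attributes; neither is shown in this diff):

import torch
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder  # path assumed from the commit title

with torch.no_grad():
    encoder = CLIPTextEncoder()                     # assumes default hyperparameters exist
    embedding = encoder.encode("a photo of a cat")  # returns a Tensor, per the signature in the diff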