cosmetic changes for text_encoder.py

limiteinductive 2023-08-16 15:19:08 +02:00 committed by Benjamin Trom
parent b8e7179447
commit 9d663534d1


@@ -1,21 +1,9 @@
 from torch import Tensor, arange, device as Device, dtype as DType
-from refiners.fluxion.layers import (
-    ApproximateGeLU,
-    GeLU,
-    Linear,
-    LayerNorm,
-    Embedding,
-    Chain,
-    Sum,
-    SelfAttention,
-    Lambda,
-    Residual,
-)
+import refiners.fluxion.layers as fl
 from refiners.foundationals.clip.tokenizer import CLIPTokenizer
 
 
-class PositionalTokenEncoder(Sum):
+class PositionalTokenEncoder(fl.Sum):
     structural_attrs = ["vocabulary_size", "positional_embedding_dim"]
 
     def __init__(
@@ -29,15 +17,15 @@ class PositionalTokenEncoder(Sum):
         self.vocabulary_size = vocabulary_size
         self.positional_embedding_dim = positional_embedding_dim
         super().__init__(
-            Embedding(
+            fl.Embedding(
                 num_embeddings=vocabulary_size,
                 embedding_dim=embedding_dim,
                 device=device,
                 dtype=dtype,
             ),
-            Chain(
-                Lambda(func=self.get_position_ids),
-                Embedding(
+            fl.Chain(
+                fl.Lambda(func=self.get_position_ids),
+                fl.Embedding(
                     num_embeddings=positional_embedding_dim,
                     embedding_dim=embedding_dim,
                     device=device,
@@ -54,7 +42,7 @@ class PositionalTokenEncoder(Sum):
         return self.position_ids[:, : x.shape[1]]
 
 
-class FeedForward(Chain):
+class FeedForward(fl.Chain):
     structural_attrs = ["embedding_dim", "feedforward_dim"]
 
     def __init__(
@@ -67,13 +55,13 @@ class FeedForward(Chain):
         self.embedding_dim = embedding_dim
         self.feedforward_dim = feedforward_dim
         super().__init__(
-            Linear(in_features=embedding_dim, out_features=feedforward_dim, device=device, dtype=dtype),
-            GeLU(),
-            Linear(in_features=feedforward_dim, out_features=embedding_dim, device=device, dtype=dtype),
+            fl.Linear(in_features=embedding_dim, out_features=feedforward_dim, device=device, dtype=dtype),
+            fl.GeLU(),
+            fl.Linear(in_features=feedforward_dim, out_features=embedding_dim, device=device, dtype=dtype),
         )
 
 
-class TransformerLayer(Chain):
+class TransformerLayer(fl.Chain):
     structural_attrs = ["embedding_dim", "num_attention_heads", "feedforward_dim", "layer_norm_eps"]
 
     def __init__(
@@ -90,14 +78,14 @@ class TransformerLayer(Chain):
         self.feedforward_dim = feedforward_dim
         self.layer_norm_eps = layer_norm_eps
         super().__init__(
-            Residual(
-                LayerNorm(
+            fl.Residual(
+                fl.LayerNorm(
                     normalized_shape=embedding_dim,
                     eps=layer_norm_eps,
                     device=device,
                     dtype=dtype,
                 ),
-                SelfAttention(
+                fl.SelfAttention(
                     embedding_dim=embedding_dim,
                     num_heads=num_attention_heads,
                     is_causal=True,
@@ -105,8 +93,8 @@ class TransformerLayer(Chain):
                     dtype=dtype,
                 ),
             ),
-            Residual(
-                LayerNorm(
+            fl.Residual(
+                fl.LayerNorm(
                     normalized_shape=embedding_dim,
                     eps=layer_norm_eps,
                     device=device,
@@ -122,7 +110,7 @@ class TransformerLayer(Chain):
             )
 
 
-class CLIPTextEncoder(Chain):
+class CLIPTextEncoder(fl.Chain):
     structural_attrs = [
         "embedding_dim",
         "positional_embedding_dim",
@@ -177,11 +165,11 @@ class CLIPTextEncoder(Chain):
                 )
                 for _ in range(num_layers)
             ),
-            LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
+            fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
         )
         if use_quick_gelu:
-            for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, GeLU)):
-                parent.replace(old_module=gelu, new_module=ApproximateGeLU())
+            for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):
+                parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())
 
     def encode(self, text: str) -> Tensor:
         tokens = self.tokenizer(text, sequence_length=self.positional_embedding_dim).to(device=self.device)