cosmetic changes for text_encoder.py

limiteinductive 2023-08-16 15:19:08 +02:00 committed by Benjamin Trom
parent b8e7179447
commit 9d663534d1


@@ -1,21 +1,9 @@
 from torch import Tensor, arange, device as Device, dtype as DType
-from refiners.fluxion.layers import (
-    ApproximateGeLU,
-    GeLU,
-    Linear,
-    LayerNorm,
-    Embedding,
-    Chain,
-    Sum,
-    SelfAttention,
-    Lambda,
-    Residual,
-)
+import refiners.fluxion.layers as fl
 from refiners.foundationals.clip.tokenizer import CLIPTokenizer


-class PositionalTokenEncoder(Sum):
+class PositionalTokenEncoder(fl.Sum):
     structural_attrs = ["vocabulary_size", "positional_embedding_dim"]

     def __init__(
@@ -29,15 +17,15 @@ class PositionalTokenEncoder(Sum):
         self.vocabulary_size = vocabulary_size
         self.positional_embedding_dim = positional_embedding_dim
         super().__init__(
-            Embedding(
+            fl.Embedding(
                 num_embeddings=vocabulary_size,
                 embedding_dim=embedding_dim,
                 device=device,
                 dtype=dtype,
             ),
-            Chain(
-                Lambda(func=self.get_position_ids),
-                Embedding(
+            fl.Chain(
+                fl.Lambda(func=self.get_position_ids),
+                fl.Embedding(
                     num_embeddings=positional_embedding_dim,
                     embedding_dim=embedding_dim,
                     device=device,
@@ -54,7 +42,7 @@ class PositionalTokenEncoder(Sum):
         return self.position_ids[:, : x.shape[1]]


-class FeedForward(Chain):
+class FeedForward(fl.Chain):
     structural_attrs = ["embedding_dim", "feedforward_dim"]

     def __init__(
@@ -67,13 +55,13 @@ class FeedForward(Chain):
         self.embedding_dim = embedding_dim
         self.feedforward_dim = feedforward_dim
         super().__init__(
-            Linear(in_features=embedding_dim, out_features=feedforward_dim, device=device, dtype=dtype),
-            GeLU(),
-            Linear(in_features=feedforward_dim, out_features=embedding_dim, device=device, dtype=dtype),
+            fl.Linear(in_features=embedding_dim, out_features=feedforward_dim, device=device, dtype=dtype),
+            fl.GeLU(),
+            fl.Linear(in_features=feedforward_dim, out_features=embedding_dim, device=device, dtype=dtype),
         )


-class TransformerLayer(Chain):
+class TransformerLayer(fl.Chain):
     structural_attrs = ["embedding_dim", "num_attention_heads", "feedforward_dim", "layer_norm_eps"]

     def __init__(
@@ -90,14 +78,14 @@ class TransformerLayer(Chain):
         self.feedforward_dim = feedforward_dim
         self.layer_norm_eps = layer_norm_eps
         super().__init__(
-            Residual(
-                LayerNorm(
+            fl.Residual(
+                fl.LayerNorm(
                     normalized_shape=embedding_dim,
                     eps=layer_norm_eps,
                     device=device,
                     dtype=dtype,
                 ),
-                SelfAttention(
+                fl.SelfAttention(
                     embedding_dim=embedding_dim,
                     num_heads=num_attention_heads,
                     is_causal=True,
@@ -105,8 +93,8 @@ class TransformerLayer(Chain):
                     dtype=dtype,
                 ),
             ),
-            Residual(
-                LayerNorm(
+            fl.Residual(
+                fl.LayerNorm(
                     normalized_shape=embedding_dim,
                     eps=layer_norm_eps,
                     device=device,
@@ -122,7 +110,7 @@ class TransformerLayer(Chain):
         )


-class CLIPTextEncoder(Chain):
+class CLIPTextEncoder(fl.Chain):
     structural_attrs = [
         "embedding_dim",
         "positional_embedding_dim",
@@ -177,11 +165,11 @@ class CLIPTextEncoder(Chain):
                 )
                 for _ in range(num_layers)
             ),
-            LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
+            fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
         )
         if use_quick_gelu:
-            for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, GeLU)):
-                parent.replace(old_module=gelu, new_module=ApproximateGeLU())
+            for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):
+                parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())

     def encode(self, text: str) -> Tensor:
         tokens = self.tokenizer(text, sequence_length=self.positional_embedding_dim).to(device=self.device)
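
Below is a minimal, illustrative sketch of how the namespaced import style introduced in this commit reads at a call site, mirroring the FeedForward block in the diff; the 768/3072 sizes are placeholder values (not taken from this commit) and device/dtype are assumed to keep their defaults.

import refiners.fluxion.layers as fl

# Same layer structure as FeedForward above, referenced through the `fl` alias
# instead of importing Chain, Linear and GeLU individually.
ff = fl.Chain(
    fl.Linear(in_features=768, out_features=3072),  # placeholder dimensions
    fl.GeLU(),
    fl.Linear(in_features=3072, out_features=768),
)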