Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-22 06:08:46 +00:00)

Commit 9d663534d1 (parent b8e7179447): cosmetic changes for text_encoder.py
text_encoder.py:

@@ -1,21 +1,9 @@
 from torch import Tensor, arange, device as Device, dtype as DType
-from refiners.fluxion.layers import (
-    ApproximateGeLU,
-    GeLU,
-    Linear,
-    LayerNorm,
-    Embedding,
-    Chain,
-    Sum,
-    SelfAttention,
-    Lambda,
-    Residual,
-)
+import refiners.fluxion.layers as fl
 from refiners.foundationals.clip.tokenizer import CLIPTokenizer
 
 
-class PositionalTokenEncoder(Sum):
+class PositionalTokenEncoder(fl.Sum):
     structural_attrs = ["vocabulary_size", "positional_embedding_dim"]
 
     def __init__(
@@ -29,15 +17,15 @@ class PositionalTokenEncoder(Sum):
         self.vocabulary_size = vocabulary_size
         self.positional_embedding_dim = positional_embedding_dim
         super().__init__(
-            Embedding(
+            fl.Embedding(
                 num_embeddings=vocabulary_size,
                 embedding_dim=embedding_dim,
                 device=device,
                 dtype=dtype,
             ),
-            Chain(
-                Lambda(func=self.get_position_ids),
-                Embedding(
+            fl.Chain(
+                fl.Lambda(func=self.get_position_ids),
+                fl.Embedding(
                     num_embeddings=positional_embedding_dim,
                     embedding_dim=embedding_dim,
                     device=device,
@@ -54,7 +42,7 @@ class PositionalTokenEncoder(Sum):
         return self.position_ids[:, : x.shape[1]]
 
 
-class FeedForward(Chain):
+class FeedForward(fl.Chain):
     structural_attrs = ["embedding_dim", "feedforward_dim"]
 
     def __init__(
@@ -67,13 +55,13 @@ class FeedForward(Chain):
         self.embedding_dim = embedding_dim
         self.feedforward_dim = feedforward_dim
         super().__init__(
-            Linear(in_features=embedding_dim, out_features=feedforward_dim, device=device, dtype=dtype),
-            GeLU(),
-            Linear(in_features=feedforward_dim, out_features=embedding_dim, device=device, dtype=dtype),
+            fl.Linear(in_features=embedding_dim, out_features=feedforward_dim, device=device, dtype=dtype),
+            fl.GeLU(),
+            fl.Linear(in_features=feedforward_dim, out_features=embedding_dim, device=device, dtype=dtype),
         )
 
 
-class TransformerLayer(Chain):
+class TransformerLayer(fl.Chain):
     structural_attrs = ["embedding_dim", "num_attention_heads", "feedforward_dim", "layer_norm_eps"]
 
     def __init__(
@@ -90,14 +78,14 @@ class TransformerLayer(Chain):
         self.feedforward_dim = feedforward_dim
         self.layer_norm_eps = layer_norm_eps
         super().__init__(
-            Residual(
-                LayerNorm(
+            fl.Residual(
+                fl.LayerNorm(
                     normalized_shape=embedding_dim,
                     eps=layer_norm_eps,
                     device=device,
                     dtype=dtype,
                 ),
-                SelfAttention(
+                fl.SelfAttention(
                     embedding_dim=embedding_dim,
                     num_heads=num_attention_heads,
                     is_causal=True,
@@ -105,8 +93,8 @@ class TransformerLayer(Chain):
                     dtype=dtype,
                 ),
             ),
-            Residual(
-                LayerNorm(
+            fl.Residual(
+                fl.LayerNorm(
                     normalized_shape=embedding_dim,
                     eps=layer_norm_eps,
                     device=device,
@@ -122,7 +110,7 @@ class TransformerLayer(Chain):
         )
 
 
-class CLIPTextEncoder(Chain):
+class CLIPTextEncoder(fl.Chain):
     structural_attrs = [
         "embedding_dim",
         "positional_embedding_dim",
@@ -177,11 +165,11 @@ class CLIPTextEncoder(Chain):
                 )
                 for _ in range(num_layers)
             ),
-            LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
+            fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
         )
         if use_quick_gelu:
-            for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, GeLU)):
-                parent.replace(old_module=gelu, new_module=ApproximateGeLU())
+            for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):
+                parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())
 
     def encode(self, text: str) -> Tensor:
         tokens = self.tokenizer(text, sequence_length=self.positional_embedding_dim).to(device=self.device)
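The change is purely a namespace cleanup: the individually imported layer classes are now referenced through `import refiners.fluxion.layers as fl`, with no behavioural difference. A minimal sketch of the resulting style, mirroring the FeedForward chain from the diff (the dimensions and tensor shape below are made up for illustration):

import torch
import refiners.fluxion.layers as fl

# Same structure as FeedForward after the change: every layer is reached
# through the fl namespace instead of a per-class import.
feed_forward = fl.Chain(
    fl.Linear(in_features=768, out_features=3072),
    fl.GeLU(),
    fl.Linear(in_features=3072, out_features=768),
)

x = torch.randn(1, 77, 768)  # hypothetical token embeddings
y = feed_forward(x)          # same shape out: (1, 77, 768)

Keeping everything behind the `fl` prefix shortens the import block and makes it clear at each call site which building blocks come from fluxion.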