diff --git a/src/refiners/foundationals/clip/text_encoder.py b/src/refiners/foundationals/clip/text_encoder.py
index c35c912..92038ea 100644
--- a/src/refiners/foundationals/clip/text_encoder.py
+++ b/src/refiners/foundationals/clip/text_encoder.py
@@ -1,21 +1,9 @@
 from torch import Tensor, arange, device as Device, dtype as DType
-
-from refiners.fluxion.layers import (
-    ApproximateGeLU,
-    GeLU,
-    Linear,
-    LayerNorm,
-    Embedding,
-    Chain,
-    Sum,
-    SelfAttention,
-    Lambda,
-    Residual,
-)
+import refiners.fluxion.layers as fl
 from refiners.foundationals.clip.tokenizer import CLIPTokenizer


-class PositionalTokenEncoder(Sum):
+class PositionalTokenEncoder(fl.Sum):
     structural_attrs = ["vocabulary_size", "positional_embedding_dim"]

     def __init__(
@@ -29,15 +17,15 @@ class PositionalTokenEncoder(Sum):
         self.vocabulary_size = vocabulary_size
         self.positional_embedding_dim = positional_embedding_dim
         super().__init__(
-            Embedding(
+            fl.Embedding(
                 num_embeddings=vocabulary_size,
                 embedding_dim=embedding_dim,
                 device=device,
                 dtype=dtype,
             ),
-            Chain(
-                Lambda(func=self.get_position_ids),
-                Embedding(
+            fl.Chain(
+                fl.Lambda(func=self.get_position_ids),
+                fl.Embedding(
                     num_embeddings=positional_embedding_dim,
                     embedding_dim=embedding_dim,
                     device=device,
@@ -54,7 +42,7 @@ class PositionalTokenEncoder(Sum):
         return self.position_ids[:, : x.shape[1]]


-class FeedForward(Chain):
+class FeedForward(fl.Chain):
     structural_attrs = ["embedding_dim", "feedforward_dim"]

     def __init__(
@@ -67,13 +55,13 @@ class FeedForward(Chain):
         self.embedding_dim = embedding_dim
         self.feedforward_dim = feedforward_dim
         super().__init__(
-            Linear(in_features=embedding_dim, out_features=feedforward_dim, device=device, dtype=dtype),
-            GeLU(),
-            Linear(in_features=feedforward_dim, out_features=embedding_dim, device=device, dtype=dtype),
+            fl.Linear(in_features=embedding_dim, out_features=feedforward_dim, device=device, dtype=dtype),
+            fl.GeLU(),
+            fl.Linear(in_features=feedforward_dim, out_features=embedding_dim, device=device, dtype=dtype),
         )


-class TransformerLayer(Chain):
+class TransformerLayer(fl.Chain):
     structural_attrs = ["embedding_dim", "num_attention_heads", "feedforward_dim", "layer_norm_eps"]

     def __init__(
@@ -90,14 +78,14 @@ class TransformerLayer(Chain):
         self.feedforward_dim = feedforward_dim
         self.layer_norm_eps = layer_norm_eps
         super().__init__(
-            Residual(
-                LayerNorm(
+            fl.Residual(
+                fl.LayerNorm(
                     normalized_shape=embedding_dim,
                     eps=layer_norm_eps,
                     device=device,
                     dtype=dtype,
                 ),
-                SelfAttention(
+                fl.SelfAttention(
                     embedding_dim=embedding_dim,
                     num_heads=num_attention_heads,
                     is_causal=True,
@@ -105,8 +93,8 @@ class TransformerLayer(Chain):
                     dtype=dtype,
                 ),
             ),
-            Residual(
-                LayerNorm(
+            fl.Residual(
+                fl.LayerNorm(
                     normalized_shape=embedding_dim,
                     eps=layer_norm_eps,
                     device=device,
@@ -122,7 +110,7 @@ class TransformerLayer(Chain):
         )


-class CLIPTextEncoder(Chain):
+class CLIPTextEncoder(fl.Chain):
     structural_attrs = [
         "embedding_dim",
         "positional_embedding_dim",
@@ -177,11 +165,11 @@ class CLIPTextEncoder(Chain):
                 )
                 for _ in range(num_layers)
             ),
-            LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
+            fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
         )
         if use_quick_gelu:
-            for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, GeLU)):
-                parent.replace(old_module=gelu, new_module=ApproximateGeLU())
+            for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):
+                parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())

     def encode(self, text: str) -> Tensor:
         tokens = self.tokenizer(text, sequence_length=self.positional_embedding_dim).to(device=self.device)
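
Note (illustrative, not part of the diff): a minimal sketch of the namespace-import style this change adopts, using only layer names that appear above. The 768/3072 feature sizes and the reliance on default device/dtype are assumptions for the example, not values taken from the repository.

    import refiners.fluxion.layers as fl

    # Explicit per-class imports are replaced by the single `fl` namespace,
    # so call sites read fl.Chain, fl.Linear, fl.GeLU, etc.
    feed_forward = fl.Chain(
        fl.Linear(in_features=768, out_features=3072),  # illustrative dimensions
        fl.GeLU(),
        fl.Linear(in_features=3072, out_features=768),
    )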