diff --git a/src/refiners/foundationals/clip/text_encoder.py b/src/refiners/foundationals/clip/text_encoder.py
index 90d35ba..188df0b 100644
--- a/src/refiners/foundationals/clip/text_encoder.py
+++ b/src/refiners/foundationals/clip/text_encoder.py
@@ -131,6 +131,7 @@ class CLIPTextEncoder(Chain):
         "num_attention_heads",
         "feedforward_dim",
         "layer_norm_eps",
+        "use_quick_gelu",
         "tokenizer",
     ]
 
@@ -143,6 +144,7 @@ class CLIPTextEncoder(Chain):
         num_attention_heads: int = 12,
         feedforward_dim: int = 3072,
         layer_norm_eps: float = 1e-5,
+        use_quick_gelu: bool = False,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
@@ -153,6 +155,7 @@ class CLIPTextEncoder(Chain):
         self.num_attention_heads = num_attention_heads
         self.feedforward_dim = feedforward_dim
         self.layer_norm_eps = layer_norm_eps
+        self.use_quick_gelu = use_quick_gelu
         self.tokenizer = CLIPTokenizer()
         super().__init__(
             PositionalTokenEncoder(
@@ -175,6 +178,9 @@ class CLIPTextEncoder(Chain):
             ),
             LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
         )
+        if use_quick_gelu:
+            for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, GeLU)):
+                parent.replace(old_module=gelu, new_module=ApproximateGeLU())
 
     def encode(self, text: str) -> Tensor:
         tokens = self.tokenizer(text, sequence_length=self.positional_embedding_dim).to(device=self.device)
@@ -192,6 +198,7 @@ class CLIPTextEncoderL(CLIPTextEncoder):
     num_layers=12
     num_attention_heads=12
     feedforward_dim=3072
+    use_quick_gelu=True
 
     We replace the GeLU activation function with an approximate GeLU to comply with the original CLIP
     implementation of OpenAI (https://github.com/openai/CLIP/blob/main/clip/model.py#L166)
@@ -203,11 +210,10 @@ class CLIPTextEncoderL(CLIPTextEncoder):
             num_layers=12,
             num_attention_heads=12,
             feedforward_dim=3072,
+            use_quick_gelu=True,
             device=device,
             dtype=dtype,
         )
-        for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, GeLU)):
-            parent.replace(old_module=gelu, new_module=ApproximateGeLU())
 
 
 class CLIPTextEncoderH(CLIPTextEncoder):
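
The diff moves the GeLU-to-ApproximateGeLU substitution out of the `CLIPTextEncoderL` constructor and into the base class behind a `use_quick_gelu` flag. A minimal usage sketch of the resulting API follows; the import path is assumed from the diff headers, and device/dtype arguments are left at their defaults:

```python
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, CLIPTextEncoderL

# By default the base encoder keeps the exact GeLU activation.
encoder = CLIPTextEncoder()  # use_quick_gelu=False

# Opting in swaps every GeLU for ApproximateGeLU ("quick GELU"),
# matching OpenAI's reference CLIP implementation.
quick_encoder = CLIPTextEncoder(use_quick_gelu=True)

# CLIPTextEncoderL now requests the substitution through the flag
# instead of patching its activations after construction.
encoder_l = CLIPTextEncoderL()
```

A note on the design: performing the walk-and-replace inside `CLIPTextEncoder.__init__` keeps the quick-GELU variant available to any configuration, rather than hard-coding it to the L-sized subclass.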