add use_quick_gelu kwarg for CLIPTextEncoder

limiteinductive 2023-08-16 13:38:09 +02:00 committed by Benjamin Trom
parent efe923a272
commit 63fda2bfd8
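
The diff below threads a use_quick_gelu flag through CLIPTextEncoder so callers can opt into OpenAI's "quick GELU" activation, and CLIPTextEncoderL now enables it through the constructor instead of patching its own modules after construction. As a usage sketch only (the import path is assumed from the repository layout and is not part of this commit):

    # Hypothetical usage; the module path is an assumption, not taken from this commit.
    from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, CLIPTextEncoderL

    # Opt in explicitly on the base encoder...
    encoder = CLIPTextEncoder(use_quick_gelu=True)

    # ...or use CLIPTextEncoderL, which now passes use_quick_gelu=True itself.
    encoder_l = CLIPTextEncoderL()
    embedding = encoder_l.encode("a photo of a cat")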


@@ -131,6 +131,7 @@ class CLIPTextEncoder(Chain):
         "num_attention_heads",
         "feedforward_dim",
         "layer_norm_eps",
+        "use_quick_gelu",
         "tokenizer",
     ]
 
@@ -143,6 +144,7 @@ class CLIPTextEncoder(Chain):
         num_attention_heads: int = 12,
         feedforward_dim: int = 3072,
         layer_norm_eps: float = 1e-5,
+        use_quick_gelu: bool = False,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
@@ -153,6 +155,7 @@ class CLIPTextEncoder(Chain):
         self.num_attention_heads = num_attention_heads
         self.feedforward_dim = feedforward_dim
         self.layer_norm_eps = layer_norm_eps
+        self.use_quick_gelu = use_quick_gelu
         self.tokenizer = CLIPTokenizer()
         super().__init__(
             PositionalTokenEncoder(
@@ -175,6 +178,9 @@ class CLIPTextEncoder(Chain):
             ),
             LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
         )
+        if use_quick_gelu:
+            for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, GeLU)):
+                parent.replace(old_module=gelu, new_module=ApproximateGeLU())
 
     def encode(self, text: str) -> Tensor:
         tokens = self.tokenizer(text, sequence_length=self.positional_embedding_dim).to(device=self.device)
@@ -192,6 +198,7 @@ class CLIPTextEncoderL(CLIPTextEncoder):
     num_layers=12
     num_attention_heads=12
     feedforward_dim=3072
+    use_quick_gelu=True
 
     We replace the GeLU activation function with an approximate GeLU to comply with the original CLIP implementation
     of OpenAI (https://github.com/openai/CLIP/blob/main/clip/model.py#L166)
@@ -203,11 +210,10 @@ class CLIPTextEncoderL(CLIPTextEncoder):
             num_layers=12,
             num_attention_heads=12,
             feedforward_dim=3072,
+            use_quick_gelu=True,
             device=device,
             dtype=dtype,
         )
-        for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, GeLU)):
-            parent.replace(old_module=gelu, new_module=ApproximateGeLU())
 
 
 class CLIPTextEncoderH(CLIPTextEncoder):
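
For reference, the "quick GELU" used by OpenAI's CLIP (the model.py#L166 link cited in the docstring above) is the sigmoid approximation x * sigmoid(1.702 * x), whereas the standard GeLU is x * Phi(x) with Phi the Gaussian CDF. A minimal, self-contained comparison in plain PyTorch, assuming refiners' ApproximateGeLU is meant to match that approximation:

    import torch
    from torch import Tensor

    def exact_gelu(x: Tensor) -> Tensor:
        # Standard GeLU: x * Phi(x), Phi being the standard normal CDF.
        return torch.nn.functional.gelu(x)

    def quick_gelu(x: Tensor) -> Tensor:
        # Sigmoid approximation used by OpenAI's CLIP (QuickGELU).
        return x * torch.sigmoid(1.702 * x)

    x = torch.linspace(-3.0, 3.0, steps=7)
    print((exact_gelu(x) - quick_gelu(x)).abs().max())  # small but nonzero difference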