add use_quick_gelu kwarg for CLIPTextEncoder

limiteinductive 2023-08-16 13:38:09 +02:00 committed by Benjamin Trom
parent efe923a272
commit 63fda2bfd8


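With this change, callers opt into the quick GeLU activation through the constructor instead of patching modules after construction. A minimal usage sketch follows; the import path is an assumption based on the package layout, and the defaults come from the diff below.

from refiners.foundationals.clip.text_encoder import CLIPTextEncoder

# Default construction keeps the standard GeLU activation (use_quick_gelu=False).
encoder = CLIPTextEncoder()

# Opting in swaps every GeLU for ApproximateGeLU at construction time.
quick_encoder = CLIPTextEncoder(use_quick_gelu=True)
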
@@ -131,6 +131,7 @@ class CLIPTextEncoder(Chain):
         "num_attention_heads",
         "feedforward_dim",
         "layer_norm_eps",
+        "use_quick_gelu",
         "tokenizer",
     ]
 
@@ -143,6 +144,7 @@ class CLIPTextEncoder(Chain):
         num_attention_heads: int = 12,
         feedforward_dim: int = 3072,
         layer_norm_eps: float = 1e-5,
+        use_quick_gelu: bool = False,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
@@ -153,6 +155,7 @@ class CLIPTextEncoder(Chain):
         self.num_attention_heads = num_attention_heads
         self.feedforward_dim = feedforward_dim
         self.layer_norm_eps = layer_norm_eps
+        self.use_quick_gelu = use_quick_gelu
         self.tokenizer = CLIPTokenizer()
         super().__init__(
             PositionalTokenEncoder(
@@ -175,6 +178,9 @@ class CLIPTextEncoder(Chain):
             ),
             LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
         )
+        if use_quick_gelu:
+            for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, GeLU)):
+                parent.replace(old_module=gelu, new_module=ApproximateGeLU())
 
     def encode(self, text: str) -> Tensor:
         tokens = self.tokenizer(text, sequence_length=self.positional_embedding_dim).to(device=self.device)
@@ -192,6 +198,7 @@ class CLIPTextEncoderL(CLIPTextEncoder):
     num_layers=12
     num_attention_heads=12
     feedforward_dim=3072
+    use_quick_gelu=True
 
     We replace the GeLU activation function with an approximate GeLU to comply with the original CLIP implementation
     of OpenAI (https://github.com/openai/CLIP/blob/main/clip/model.py#L166)
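
For reference, the QuickGELU in the linked OpenAI file computes x * sigmoid(1.702 * x), a cheap approximation of the exact GELU x * Φ(x); ApproximateGeLU is assumed to implement that same sigmoid form. A standalone comparison in plain PyTorch, independent of refiners:

import torch

def exact_gelu(x: torch.Tensor) -> torch.Tensor:
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return x * 0.5 * (1.0 + torch.erf(x / 2**0.5))

def quick_gelu(x: torch.Tensor) -> torch.Tensor:
    # QuickGELU as in OpenAI CLIP (model.py#L166): x * sigmoid(1.702 * x).
    return x * torch.sigmoid(1.702 * x)

x = torch.linspace(-4.0, 4.0, steps=9)
print(torch.max(torch.abs(exact_gelu(x) - quick_gelu(x))))  # small but nonzero difference
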
@@ -203,11 +210,10 @@ class CLIPTextEncoderL(CLIPTextEncoder):
             num_layers=12,
             num_attention_heads=12,
             feedforward_dim=3072,
+            use_quick_gelu=True,
             device=device,
             dtype=dtype,
         )
-        for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, GeLU)):
-            parent.replace(old_module=gelu, new_module=ApproximateGeLU())
 
 
 class CLIPTextEncoderH(CLIPTextEncoder):
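
With the flag in place, CLIPTextEncoderL no longer needs the post-init walk-and-replace loop; the base class performs the swap when use_quick_gelu=True. A quick inspection sketch, assuming the import paths below and reusing the (module, parent) walk API shown in the diff:

from refiners.fluxion.layers import ApproximateGeLU
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL

encoder = CLIPTextEncoderL()
# walk yields (module, parent) pairs matching the predicate.
approx = [m for m, _ in encoder.walk(predicate=lambda m, _: isinstance(m, ApproximateGeLU))]
print(f"ApproximateGeLU activations found: {len(approx)}")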