Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-13 00:28:14 +00:00)
add use_quick_gelu kwarg for CLIPTextEncoder
commit 63fda2bfd8 (parent efe923a272)
@@ -131,6 +131,7 @@ class CLIPTextEncoder(Chain):
         "num_attention_heads",
         "feedforward_dim",
         "layer_norm_eps",
+        "use_quick_gelu",
         "tokenizer",
     ]

@@ -143,6 +144,7 @@ class CLIPTextEncoder(Chain):
         num_attention_heads: int = 12,
         feedforward_dim: int = 3072,
         layer_norm_eps: float = 1e-5,
+        use_quick_gelu: bool = False,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
@@ -153,6 +155,7 @@ class CLIPTextEncoder(Chain):
         self.num_attention_heads = num_attention_heads
         self.feedforward_dim = feedforward_dim
         self.layer_norm_eps = layer_norm_eps
+        self.use_quick_gelu = use_quick_gelu
         self.tokenizer = CLIPTokenizer()
         super().__init__(
             PositionalTokenEncoder(
@@ -175,6 +178,9 @@ class CLIPTextEncoder(Chain):
             ),
             LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
         )
+        if use_quick_gelu:
+            for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, GeLU)):
+                parent.replace(old_module=gelu, new_module=ApproximateGeLU())

     def encode(self, text: str) -> Tensor:
         tokens = self.tokenizer(text, sequence_length=self.positional_embedding_dim).to(device=self.device)
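For context, a brief usage sketch of the new kwarg (not part of the commit; the import paths below are assumptions based on the refiners layout at this revision): with use_quick_gelu=True the base class swaps every GeLU in the chain for ApproximateGeLU at construction time, so callers no longer have to walk and replace modules themselves.

    # Minimal sketch; import paths are assumptions, not taken from the diff.
    from refiners.fluxion.layers import ApproximateGeLU
    from refiners.foundationals.clip.text_encoder import CLIPTextEncoder

    default_encoder = CLIPTextEncoder()                   # keeps exact GeLU activations
    quick_encoder = CLIPTextEncoder(use_quick_gelu=True)  # GeLU -> ApproximateGeLU at init

    # Same walk(predicate=...) traversal the commit uses to perform the swap:
    swapped = [m for m, _ in quick_encoder.walk(predicate=lambda m, _: isinstance(m, ApproximateGeLU))]
    print(len(swapped))  # expected: one per transformer layer (num_layers defaults to 12)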
@@ -192,6 +198,7 @@ class CLIPTextEncoderL(CLIPTextEncoder):
         num_layers=12
         num_attention_heads=12
         feedforward_dim=3072
+        use_quick_gelu=True

     We replace the GeLU activation function with an approximate GeLU to comply with the original CLIP implementation
     of OpenAI (https://github.com/openai/CLIP/blob/main/clip/model.py#L166)
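As a side note on what "quick GeLU" means here: the OpenAI CLIP code linked in the docstring approximates GeLU with a sigmoid, gelu(x) ≈ x * sigmoid(1.702 * x). A small standalone PyTorch sketch (illustrative only, not taken from refiners):

    import torch

    def quick_gelu(x: torch.Tensor) -> torch.Tensor:
        # Sigmoid-based approximation used by OpenAI's CLIP ("QuickGELU"):
        # gelu(x) ~ x * sigmoid(1.702 * x)
        return x * torch.sigmoid(1.702 * x)

    x = torch.linspace(-3, 3, steps=7)
    print(torch.nn.functional.gelu(x))  # exact GeLU
    print(quick_gelu(x))                # close, but not bit-identical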
@@ -203,11 +210,10 @@ class CLIPTextEncoderL(CLIPTextEncoder):
             num_layers=12,
             num_attention_heads=12,
             feedforward_dim=3072,
+            use_quick_gelu=True,
             device=device,
             dtype=dtype,
         )
-        for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, GeLU)):
-            parent.replace(old_module=gelu, new_module=ApproximateGeLU())


 class CLIPTextEncoderH(CLIPTextEncoder):
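With the kwarg in place, CLIPTextEncoderL simply forwards use_quick_gelu=True and drops its manual walk-and-replace loop. A hypothetical check of the resulting encoder (import paths are assumptions, not part of the commit):

    # Hypothetical usage sketch for CLIPTextEncoderL after this commit.
    from refiners.fluxion.layers import GeLU
    from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL

    encoder = CLIPTextEncoderL()  # passes use_quick_gelu=True to the base class
    # The base class already swapped the activations, so no exact GeLU should remain:
    remaining = [m for m, _ in encoder.walk(predicate=lambda m, _: type(m) is GeLU)]
    assert not remaining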