(doc/foundationals) add CLIP, related docstrings

Laurent 2024-02-02 13:27:47 +00:00 committed by Laureηt
parent 8befede3cf
commit 7bc5ce35d2
3 changed files with 155 additions and 21 deletions

src/refiners/foundationals/clip/__init__.py (View file)

@@ -0,0 +1,21 @@
from refiners.foundationals.clip.image_encoder import (
CLIPImageEncoder,
CLIPImageEncoderG,
CLIPImageEncoderH,
)
from refiners.foundationals.clip.text_encoder import (
CLIPTextEncoder,
CLIPTextEncoderG,
CLIPTextEncoderH,
CLIPTextEncoderL,
)
__all__ = [
"CLIPTextEncoder",
"CLIPTextEncoderL",
"CLIPTextEncoderH",
"CLIPTextEncoderG",
"CLIPImageEncoder",
"CLIPImageEncoderG",
"CLIPImageEncoderH",
]
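The new __init__.py only re-exports the encoder classes, so after this commit they can be imported directly from the subpackage. A minimal sketch, assuming the refiners package is installed; both imports below should resolve to the same class objects:

# Subpackage-level import made possible by the re-exports above.
from refiners.foundationals.clip import CLIPTextEncoderL

# Module-level import, which already worked before this commit.
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL as TextEncoderL

# Both names refer to the same class object.
assert CLIPTextEncoderL is TextEncoderL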

src/refiners/foundationals/clip/image_encoder.py (View file)

@@ -108,6 +108,12 @@ class ViTEmbeddings(fl.Chain):
class CLIPImageEncoder(fl.Chain):
"""Contrastive Language-Image Pretraining (CLIP) image encoder.
See [[arXiv:2103.00020] Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020)
for more details.
"""
def __init__(
self,
image_size: int = 224,
@@ -121,6 +127,20 @@ class CLIPImageEncoder(fl.Chain):
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
"""Initialize a CLIP image encoder.
Args:
image_size: The size of the input image.
embedding_dim: The dimension of the embedding.
output_dim: The dimension of the output.
patch_size: The size of the patches.
num_layers: The number of layers.
num_attention_heads: The number of attention heads.
feedforward_dim: The dimension of the feedforward layer.
layer_norm_eps: The epsilon value for normalization.
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
self.image_size = image_size
self.embedding_dim = embedding_dim
self.output_dim = output_dim
@@ -152,7 +172,27 @@ class CLIPImageEncoder(fl.Chain):
class CLIPImageEncoderH(CLIPImageEncoder):
"""CLIP huge image encoder.
See [[arXiv:2103.00020] Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020)
for more details.
Attributes:
embedding_dim (int): 1280
output_dim (int): 1024
patch_size (int): 14
num_layers (int): 32
num_attention_heads (int): 16
feedforward_dim (int): 5120
"""
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
"""Initialize CLIP huge image encoder.
Args:
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
super().__init__(
embedding_dim=1280,
output_dim=1024,
@@ -166,7 +206,27 @@ class CLIPImageEncoderH(CLIPImageEncoder):
class CLIPImageEncoderG(CLIPImageEncoder):
"""CLIP giant image encoder.
See [[arXiv:2103.00020] Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020)
for more details.
Attributes:
embedding_dim (int): 1664
output_dim (int): 1280
patch_size (int): 14
num_layers (int): 48
num_attention_heads (int): 16
feedforward_dim (int): 8192
"""
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
"""Initialize CLIP giant image encoder.
Args:
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
super().__init__(
embedding_dim=1664,
output_dim=1280,
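A hedged usage sketch of the image encoder API documented above: it instantiates CLIPImageEncoderH and runs a dummy batch through it. The 224x224 input size and the expected (1, 1024) output shape are read off the defaults and Attributes in the docstrings, not guaranteed by this snippet; the model is randomly initialized here, so the output values are only useful for shape checking.

import torch

from refiners.foundationals.clip import CLIPImageEncoderH

# Build the huge image encoder on CPU (this is a large model);
# pretrained weights would normally be loaded separately.
encoder = CLIPImageEncoderH(device="cpu", dtype=torch.float32)

# Dummy batch of one 224x224 RGB image (random values, not a real picture).
image = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    embedding = encoder(image)

print(embedding.shape)  # expected: torch.Size([1, 1024]), i.e. (batch, output_dim)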

src/refiners/foundationals/clip/text_encoder.py (View file)

@@ -71,6 +71,12 @@ class TransformerLayer(fl.Chain):
class CLIPTextEncoder(fl.Chain):
"""Contrastive Language-Image Pretraining (CLIP) text encoder.
See [[arXiv:2103.00020] Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020)
for more details.
"""
def __init__(
self,
embedding_dim: int = 768,
@@ -85,6 +91,21 @@ class CLIPTextEncoder(fl.Chain):
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
"""Initialize CLIP text encoder.
Args:
embedding_dim: The embedding dimension.
max_sequence_length: The maximum sequence length.
vocabulary_size: The vocabulary size.
num_layers: The number of layers.
num_attention_heads: The number of attention heads.
feedforward_dim: The feedforward dimension.
layer_norm_eps: The epsilon value for layer normalization.
use_quick_gelu: Whether to use the quick GeLU activation function.
tokenizer: The tokenizer.
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
self.embedding_dim = embedding_dim
self.max_sequence_length = max_sequence_length
self.vocabulary_size = vocabulary_size
@@ -129,19 +150,30 @@ class CLIPTextEncoder(fl.Chain):
class CLIPTextEncoderL(CLIPTextEncoder):
"""CLIP large text encoder.
Note:
We replace the GeLU activation function with an approximate GeLU to comply with the original CLIP implementation
of OpenAI (https://github.com/openai/CLIP/blob/main/clip/model.py#L166)
See [[arXiv:2103.00020] Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020)
for more details.
Attributes:
embedding_dim (int): 768
num_layers (int): 12
num_attention_heads (int): 12
feedforward_dim (int): 3072
use_quick_gelu (bool): True
"""
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
"""Initialize CLIP large text encoder.
Args:
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
super().__init__(
embedding_dim=768,
num_layers=12,
@@ -154,15 +186,25 @@ class CLIPTextEncoderL(CLIPTextEncoder):
class CLIPTextEncoderH(CLIPTextEncoder):
"""CLIP huge text encoder.
See [[arXiv:2103.00020] Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020)
for more details.
Attributes:
embedding_dim (int): 1024
num_layers (int): 23
num_attention_heads (int): 16
feedforward_dim (int): 4096
"""
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
"""Initialize CLIP huge text encoder.
Args:
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
super().__init__(
embedding_dim=1024,
num_layers=23,
@@ -174,15 +216,26 @@ class CLIPTextEncoderH(CLIPTextEncoder):
class CLIPTextEncoderG(CLIPTextEncoder):
"""CLIP giant text encoder.
See [[arXiv:2103.00020] Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020)
for more details.
Attributes:
embedding_dim (int): 1280
num_layers (int): 32
num_attention_heads (int): 20
feedforward_dim (int): 5120
tokenizer (CLIPTokenizer): CLIPTokenizer(pad_token_id=0)
"""
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
"""Initialize CLIP giant text encoder.
Args:
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
tokenizer = CLIPTokenizer(pad_token_id=0)
super().__init__(
embedding_dim=1280,
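Similarly, a hedged usage sketch for the text encoders documented above: CLIPTextEncoderL is built with a default CLIPTokenizer as part of the chain (see the tokenizer argument in the docstrings), so the sketch assumes the encoder accepts a raw prompt string directly; the expected (1, 77, 768) output shape follows from max_sequence_length and embedding_dim, and the weights are again randomly initialized, so only shapes are meaningful here.

import torch

from refiners.foundationals.clip import CLIPTextEncoderL

# Build the large text encoder on CPU; pretrained weights would be loaded separately.
encoder = CLIPTextEncoderL(device="cpu", dtype=torch.float32)

with torch.no_grad():
    embedding = encoder("a photo of a cat")

# expected: torch.Size([1, 77, 768]), i.e. (batch, max_sequence_length, embedding_dim)
print(embedding.shape)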