(doc/foundationals) add DINOv2, related docstrings

Laurent 2024-02-02 13:12:13 +00:00 committed by Laureηt
parent fc7b4dd62d
commit 3910845e29
2 changed files with 137 additions and 2 deletions

src/refiners/foundationals/dinov2/dinov2.py

@@ -7,11 +7,30 @@ from refiners.foundationals.dinov2.vit import ViT
class DINOv2_small(ViT):
    """DINOv2 small model.

    See [[arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193)
    for more details.

    Attributes:
        embedding_dim (int): 384
        patch_size (int): 14
        image_size (int): 518
        num_layers (int): 12
        num_heads (int): 6
    """

    def __init__(
        self,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Initialize DINOv2 small model.

        Args:
            device: The PyTorch device to use.
            dtype: The PyTorch data type to use.
        """
        super().__init__(
            embedding_dim=384,
            patch_size=14,
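
A minimal usage sketch for the model documented above, assuming the package's usual export path refiners.foundationals.dinov2 and that the ViT chain returns the full token sequence:

import torch
from refiners.foundationals.dinov2 import DINOv2_small  # assumed export path

model = DINOv2_small(device="cpu", dtype=torch.float32)
image = torch.randn(1, 3, 518, 518)  # 518 / 14 = 37 patches per side
tokens = model(image)
# Expected under the assumptions above: 1 class token + 37 * 37 = 1369
# patch tokens, each of dimension 384.
print(tokens.shape)  # torch.Size([1, 1370, 384])
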
@@ -24,11 +43,30 @@ class DINOv2_small(ViT):
class DINOv2_base(ViT):
    """DINOv2 base model.

    See [[arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193)
    for more details.

    Attributes:
        embedding_dim (int): 768
        patch_size (int): 14
        image_size (int): 518
        num_layers (int): 12
        num_heads (int): 12
    """

    def __init__(
        self,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Initialize DINOv2 base model.

        Args:
            device: The PyTorch device to use.
            dtype: The PyTorch data type to use.
        """
        super().__init__(
            embedding_dim=768,
            patch_size=14,
@@ -41,11 +79,30 @@ class DINOv2_base(ViT):
class DINOv2_large(ViT):
    """DINOv2 large model.

    See [[arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193)
    for more details.

    Attributes:
        embedding_dim (int): 1024
        patch_size (int): 14
        image_size (int): 518
        num_layers (int): 24
        num_heads (int): 16
    """

    def __init__(
        self,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Initialize DINOv2 large model.

        Args:
            device: The PyTorch device to use.
            dtype: The PyTorch data type to use.
        """
        super().__init__(
            embedding_dim=1024,
            patch_size=14,
@@ -76,11 +133,32 @@ class DINOv2_large(ViT):
class DINOv2_small_reg(ViT):
    """DINOv2 small model with registers.

    See [[arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193)
    and [[arXiv:2309.16588] Vision Transformers Need Registers](https://arxiv.org/abs/2309.16588)
    for more details.

    Attributes:
        embedding_dim (int): 384
        patch_size (int): 14
        image_size (int): 518
        num_layers (int): 12
        num_heads (int): 6
        num_registers (int): 4
    """

    def __init__(
        self,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Initialize DINOv2 small model with registers.

        Args:
            device: The PyTorch device to use.
            dtype: The PyTorch data type to use.
        """
        super().__init__(
            embedding_dim=384,
            patch_size=14,
@@ -94,11 +172,32 @@ class DINOv2_small_reg(ViT):
class DINOv2_base_reg(ViT):
    """DINOv2 base model with registers.

    See [[arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193)
    and [[arXiv:2309.16588] Vision Transformers Need Registers](https://arxiv.org/abs/2309.16588)
    for more details.

    Attributes:
        embedding_dim (int): 768
        patch_size (int): 14
        image_size (int): 518
        num_layers (int): 12
        num_heads (int): 12
        num_registers (int): 4
    """

    def __init__(
        self,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Initialize DINOv2 base model with registers.

        Args:
            device: The PyTorch device to use.
            dtype: The PyTorch data type to use.
        """
        super().__init__(
            embedding_dim=768,
            patch_size=14,
@@ -112,11 +211,32 @@ class DINOv2_base_reg(ViT):
class DINOv2_large_reg(ViT):
    """DINOv2 large model with registers.

    See [[arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193)
    and [[arXiv:2309.16588] Vision Transformers Need Registers](https://arxiv.org/abs/2309.16588)
    for more details.

    Attributes:
        embedding_dim (int): 1024
        patch_size (int): 14
        image_size (int): 518
        num_layers (int): 24
        num_heads (int): 16
        num_registers (int): 4
    """

    def __init__(
        self,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Initialize DINOv2 large model with registers.

        Args:
            device: The PyTorch device to use.
            dtype: The PyTorch data type to use.
        """
        super().__init__(
            embedding_dim=1024,
            patch_size=14,
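
The _reg variants documented above differ from their plain counterparts only in passing num_registers=4 to ViT; a short sketch, assuming the same export path as above and that the Registers layer (an fl.Concatenate, see the second file below) appends the learned register tokens to the token sequence:

import torch
from refiners.foundationals.dinov2 import DINOv2_small, DINOv2_small_reg  # assumed export path

plain = DINOv2_small()
with_registers = DINOv2_small_reg()

# Both variants share embedding_dim=384, patch_size=14, image_size=518;
# the _reg variant adds 4 register tokens (arXiv:2309.16588) that absorb
# the high-norm artifact tokens described in the paper.
assert plain.embedding_dim == with_registers.embedding_dim == 384
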

src/refiners/foundationals/dinov2/vit.py

@@ -227,9 +227,10 @@ class Registers(fl.Concatenate):
class ViT(fl.Chain):
    """Vision Transformer (ViT) model.

    See [[arXiv:2010.11929] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929)
    for more details.
    """

    def __init__(
@@ -245,6 +246,20 @@ class ViT(fl.Chain):
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Initialize a Vision Transformer (ViT) model.

        Args:
            embedding_dim: The dimension of the embedding.
            patch_size: The size of the patches.
            image_size: The size of the input image.
            num_layers: The number of transformer layers.
            num_heads: The number of attention heads.
            norm_eps: The epsilon value used by layer normalization.
            mlp_ratio: The ratio of the MLP hidden dimension to the embedding dimension.
            num_registers: The number of register tokens.
            device: The PyTorch device to use.
            dtype: The PyTorch data type to use.
        """
        num_patches = image_size // patch_size
        self.embedding_dim = embedding_dim
        self.patch_size = patch_size
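
As a worked example of the patch-grid arithmetic in the constructor body above, for the default DINOv2 configuration (a sketch only; the register variants append num_registers extra tokens):

image_size, patch_size = 518, 14
num_patches = image_size // patch_size  # 37 patches per side
patch_tokens = num_patches**2           # 37 * 37 = 1369 patch tokens
sequence_length = patch_tokens + 1      # plus 1 class token -> 1370
print(num_patches, sequence_length)     # 37 1370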