From ae19892d1d86cc347434b3665309af6b40fe1c82 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Mon, 29 Jan 2024 16:15:29 +0100 Subject: [PATCH] remove unused ViT variations --- src/refiners/foundationals/dinov2/__init__.py | 12 +--- src/refiners/foundationals/dinov2/vit.py | 68 ------------------- 2 files changed, 1 insertion(+), 79 deletions(-) diff --git a/src/refiners/foundationals/dinov2/__init__.py b/src/refiners/foundationals/dinov2/__init__.py index 8e802bb..91cfa79 100644 --- a/src/refiners/foundationals/dinov2/__init__.py +++ b/src/refiners/foundationals/dinov2/__init__.py @@ -6,13 +6,7 @@ from .dinov2 import ( DINOv2_small, DINOv2_small_reg, ) -from .vit import ( - ViT, - ViT_base, - ViT_large, - ViT_small, - ViT_tiny, -) +from .vit import ViT __all__ = [ "DINOv2_base", @@ -22,8 +16,4 @@ __all__ = [ "DINOv2_small", "DINOv2_small_reg", "ViT", - "ViT_base", - "ViT_large", - "ViT_small", - "ViT_tiny", ] diff --git a/src/refiners/foundationals/dinov2/vit.py b/src/refiners/foundationals/dinov2/vit.py index 2045fd3..eb08ee2 100644 --- a/src/refiners/foundationals/dinov2/vit.py +++ b/src/refiners/foundationals/dinov2/vit.py @@ -305,71 +305,3 @@ class ViT(fl.Chain): dtype=dtype, ) self.insert_before_type(Transformer, registers) - - -class ViT_tiny(ViT): - def __init__( - self, - device: torch.device | str | None = None, - dtype: torch.dtype | None = None, - ) -> None: - super().__init__( - embedding_dim=192, - patch_size=16, - image_size=224, - num_layers=12, - num_heads=3, - device=device, - dtype=dtype, - ) - - -class ViT_small(ViT): - def __init__( - self, - device: torch.device | str | None = None, - dtype: torch.dtype | None = None, - ) -> None: - super().__init__( - embedding_dim=384, - patch_size=16, - image_size=224, - num_layers=12, - num_heads=6, - device=device, - dtype=dtype, - ) - - -class ViT_base(ViT): - def __init__( - self, - device: torch.device | str | None = None, - dtype: torch.dtype | None = None, - ) -> None: - super().__init__( - embedding_dim=768, - patch_size=16, - image_size=224, - num_layers=12, - num_heads=12, - device=device, - dtype=dtype, - ) - - -class ViT_large(ViT): - def __init__( - self, - device: torch.device | str | None = None, - dtype: torch.dtype | None = None, - ) -> None: - super().__init__( - embedding_dim=1024, - patch_size=16, - image_size=224, - num_layers=24, - num_heads=16, - device=device, - dtype=dtype, - )