remove unused ViT variations

This commit is contained in:
Pierre Chapuis 2024-01-29 16:15:29 +01:00 committed by Cédric Deltheil
parent 849c0058df
commit ae19892d1d
2 changed files with 1 additions and 79 deletions

View file

@ -6,13 +6,7 @@ from .dinov2 import (
DINOv2_small, DINOv2_small,
DINOv2_small_reg, DINOv2_small_reg,
) )
from .vit import ( from .vit import ViT
ViT,
ViT_base,
ViT_large,
ViT_small,
ViT_tiny,
)
__all__ = [ __all__ = [
"DINOv2_base", "DINOv2_base",
@ -22,8 +16,4 @@ __all__ = [
"DINOv2_small", "DINOv2_small",
"DINOv2_small_reg", "DINOv2_small_reg",
"ViT", "ViT",
"ViT_base",
"ViT_large",
"ViT_small",
"ViT_tiny",
] ]

View file

@ -305,71 +305,3 @@ class ViT(fl.Chain):
dtype=dtype, dtype=dtype,
) )
self.insert_before_type(Transformer, registers) self.insert_before_type(Transformer, registers)
class ViT_tiny(ViT):
def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
embedding_dim=192,
patch_size=16,
image_size=224,
num_layers=12,
num_heads=3,
device=device,
dtype=dtype,
)
class ViT_small(ViT):
def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
embedding_dim=384,
patch_size=16,
image_size=224,
num_layers=12,
num_heads=6,
device=device,
dtype=dtype,
)
class ViT_base(ViT):
def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
embedding_dim=768,
patch_size=16,
image_size=224,
num_layers=12,
num_heads=12,
device=device,
dtype=dtype,
)
class ViT_large(ViT):
def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
embedding_dim=1024,
patch_size=16,
image_size=224,
num_layers=24,
num_heads=16,
device=device,
dtype=dtype,
)