mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
remove unused ViT variations
This commit is contained in:
parent
849c0058df
commit
ae19892d1d
|
@ -6,13 +6,7 @@ from .dinov2 import (
|
|||
DINOv2_small,
|
||||
DINOv2_small_reg,
|
||||
)
|
||||
from .vit import (
|
||||
ViT,
|
||||
ViT_base,
|
||||
ViT_large,
|
||||
ViT_small,
|
||||
ViT_tiny,
|
||||
)
|
||||
from .vit import ViT
|
||||
|
||||
__all__ = [
|
||||
"DINOv2_base",
|
||||
|
@ -22,8 +16,4 @@ __all__ = [
|
|||
"DINOv2_small",
|
||||
"DINOv2_small_reg",
|
||||
"ViT",
|
||||
"ViT_base",
|
||||
"ViT_large",
|
||||
"ViT_small",
|
||||
"ViT_tiny",
|
||||
]
|
||||
|
|
|
@ -305,71 +305,3 @@ class ViT(fl.Chain):
|
|||
dtype=dtype,
|
||||
)
|
||||
self.insert_before_type(Transformer, registers)
|
||||
|
||||
|
||||
class ViT_tiny(ViT):
|
||||
def __init__(
|
||||
self,
|
||||
device: torch.device | str | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
embedding_dim=192,
|
||||
patch_size=16,
|
||||
image_size=224,
|
||||
num_layers=12,
|
||||
num_heads=3,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
class ViT_small(ViT):
|
||||
def __init__(
|
||||
self,
|
||||
device: torch.device | str | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
embedding_dim=384,
|
||||
patch_size=16,
|
||||
image_size=224,
|
||||
num_layers=12,
|
||||
num_heads=6,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
class ViT_base(ViT):
|
||||
def __init__(
|
||||
self,
|
||||
device: torch.device | str | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
embedding_dim=768,
|
||||
patch_size=16,
|
||||
image_size=224,
|
||||
num_layers=12,
|
||||
num_heads=12,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
class ViT_large(ViT):
|
||||
def __init__(
|
||||
self,
|
||||
device: torch.device | str | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
embedding_dim=1024,
|
||||
patch_size=16,
|
||||
image_size=224,
|
||||
num_layers=24,
|
||||
num_heads=16,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue