mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
remove unused ViT variations
This commit is contained in:
parent
849c0058df
commit
ae19892d1d
|
@ -6,13 +6,7 @@ from .dinov2 import (
|
||||||
DINOv2_small,
|
DINOv2_small,
|
||||||
DINOv2_small_reg,
|
DINOv2_small_reg,
|
||||||
)
|
)
|
||||||
from .vit import (
|
from .vit import ViT
|
||||||
ViT,
|
|
||||||
ViT_base,
|
|
||||||
ViT_large,
|
|
||||||
ViT_small,
|
|
||||||
ViT_tiny,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DINOv2_base",
|
"DINOv2_base",
|
||||||
|
@ -22,8 +16,4 @@ __all__ = [
|
||||||
"DINOv2_small",
|
"DINOv2_small",
|
||||||
"DINOv2_small_reg",
|
"DINOv2_small_reg",
|
||||||
"ViT",
|
"ViT",
|
||||||
"ViT_base",
|
|
||||||
"ViT_large",
|
|
||||||
"ViT_small",
|
|
||||||
"ViT_tiny",
|
|
||||||
]
|
]
|
||||||
|
|
|
@ -305,71 +305,3 @@ class ViT(fl.Chain):
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
self.insert_before_type(Transformer, registers)
|
self.insert_before_type(Transformer, registers)
|
||||||
|
|
||||||
|
|
||||||
class ViT_tiny(ViT):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
device: torch.device | str | None = None,
|
|
||||||
dtype: torch.dtype | None = None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(
|
|
||||||
embedding_dim=192,
|
|
||||||
patch_size=16,
|
|
||||||
image_size=224,
|
|
||||||
num_layers=12,
|
|
||||||
num_heads=3,
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ViT_small(ViT):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
device: torch.device | str | None = None,
|
|
||||||
dtype: torch.dtype | None = None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(
|
|
||||||
embedding_dim=384,
|
|
||||||
patch_size=16,
|
|
||||||
image_size=224,
|
|
||||||
num_layers=12,
|
|
||||||
num_heads=6,
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ViT_base(ViT):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
device: torch.device | str | None = None,
|
|
||||||
dtype: torch.dtype | None = None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(
|
|
||||||
embedding_dim=768,
|
|
||||||
patch_size=16,
|
|
||||||
image_size=224,
|
|
||||||
num_layers=12,
|
|
||||||
num_heads=12,
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ViT_large(ViT):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
device: torch.device | str | None = None,
|
|
||||||
dtype: torch.dtype | None = None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(
|
|
||||||
embedding_dim=1024,
|
|
||||||
patch_size=16,
|
|
||||||
image_size=224,
|
|
||||||
num_layers=24,
|
|
||||||
num_heads=16,
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
|
|
Loading…
Reference in a new issue