mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
feature: add DINOv2
Co-authored-by: Benjamin Trom <benjamintrom@gmail.com>
This commit is contained in:
parent
e2f2e33add
commit
9337d65e0e
135
scripts/conversion/convert_dinov2.py
Normal file
135
scripts/conversion/convert_dinov2.py
Normal file
|
@ -0,0 +1,135 @@
|
|||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:
|
||||
"""Convert a DINOv2 weights from facebook to refiners."""
|
||||
# get depth from "blocks" keys
|
||||
depth = max([int(k.split(".")[1]) for k in weights.keys() if k.startswith("blocks.")]) + 1
|
||||
|
||||
# only needed when pre-training
|
||||
del weights["mask_token"]
|
||||
|
||||
# squeeze cls_token and position_embeddings
|
||||
weights["cls_token"] = weights["cls_token"].squeeze(0)
|
||||
weights["pos_embed"] = weights["pos_embed"].squeeze(0)
|
||||
|
||||
rename_keys: list[tuple[str, str]] = [
|
||||
("cls_token", "Concatenate.ClassToken.Parameter.weight"),
|
||||
("pos_embed", "PositionalEncoder.Parameter.weight"),
|
||||
("patch_embed.proj.weight", "Concatenate.PatchEncoder.Conv2d.weight"),
|
||||
("patch_embed.proj.bias", "Concatenate.PatchEncoder.Conv2d.bias"),
|
||||
("norm.weight", "LayerNorm.weight"),
|
||||
("norm.bias", "LayerNorm.bias"),
|
||||
]
|
||||
for i in range(depth):
|
||||
rename_keys.append(
|
||||
(
|
||||
f"blocks.{i}.norm1.weight",
|
||||
f"Transformer.TransformerLayer_{i+1}.Residual_1.LayerNorm.weight",
|
||||
),
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"blocks.{i}.norm1.bias",
|
||||
f"Transformer.TransformerLayer_{i+1}.Residual_1.LayerNorm.bias",
|
||||
),
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"blocks.{i}.attn.proj.weight",
|
||||
f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Linear.weight",
|
||||
),
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"blocks.{i}.attn.proj.bias",
|
||||
f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Linear.bias",
|
||||
),
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"blocks.{i}.ls1.gamma",
|
||||
f"Transformer.TransformerLayer_{i+1}.Residual_1.LayerScale.weight",
|
||||
),
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"blocks.{i}.norm2.weight",
|
||||
f"Transformer.TransformerLayer_{i+1}.Residual_2.LayerNorm.weight",
|
||||
),
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"blocks.{i}.norm2.bias",
|
||||
f"Transformer.TransformerLayer_{i+1}.Residual_2.LayerNorm.bias",
|
||||
),
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"blocks.{i}.mlp.fc1.weight",
|
||||
f"Transformer.TransformerLayer_{i+1}.Residual_2.FeedForward.Linear_1.weight",
|
||||
),
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"blocks.{i}.mlp.fc1.bias",
|
||||
f"Transformer.TransformerLayer_{i+1}.Residual_2.FeedForward.Linear_1.bias",
|
||||
),
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"blocks.{i}.mlp.fc2.weight",
|
||||
f"Transformer.TransformerLayer_{i+1}.Residual_2.FeedForward.Linear_2.weight",
|
||||
),
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"blocks.{i}.mlp.fc2.bias",
|
||||
f"Transformer.TransformerLayer_{i+1}.Residual_2.FeedForward.Linear_2.bias",
|
||||
),
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"blocks.{i}.ls2.gamma",
|
||||
f"Transformer.TransformerLayer_{i+1}.Residual_2.LayerScale.weight",
|
||||
),
|
||||
)
|
||||
|
||||
if "register_tokens" in weights:
|
||||
weights["register_tokens"] = weights["register_tokens"].squeeze(0)
|
||||
rename_keys.append(("register_tokens", "Registers.Parameter.weight"))
|
||||
|
||||
# rename keys
|
||||
for old_key, new_key in rename_keys:
|
||||
weights[new_key] = weights.pop(old_key)
|
||||
|
||||
# split the qkv weights and biases
|
||||
for i in range(depth):
|
||||
qkv_weight = weights.pop(f"blocks.{i}.attn.qkv.weight")
|
||||
q_weight, k_weight, v_weight = qkv_weight.chunk(3, dim=0)
|
||||
weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_1.weight"] = q_weight
|
||||
weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_2.weight"] = k_weight
|
||||
weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_3.weight"] = v_weight
|
||||
|
||||
qkv_bias = weights.pop(f"blocks.{i}.attn.qkv.bias")
|
||||
q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
|
||||
weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_1.bias"] = q_bias
|
||||
weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_2.bias"] = k_bias
|
||||
weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_3.bias"] = v_bias
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--weights_path", type=str, required=True)
|
||||
parser.add_argument("--output_path", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
weights = torch.load(args.weights_path)
|
||||
convert_dinov2_facebook(weights)
|
||||
torch.save(weights, args.output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
29
src/refiners/foundationals/dinov2/__init__.py
Normal file
29
src/refiners/foundationals/dinov2/__init__.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
from .dinov2 import (
|
||||
DINOv2_base,
|
||||
DINOv2_base_reg,
|
||||
DINOv2_large,
|
||||
DINOv2_large_reg,
|
||||
DINOv2_small,
|
||||
DINOv2_small_reg,
|
||||
)
|
||||
from .vit import (
|
||||
ViT,
|
||||
ViT_base,
|
||||
ViT_large,
|
||||
ViT_small,
|
||||
ViT_tiny,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DINOv2_base",
|
||||
"DINOv2_base_reg",
|
||||
"DINOv2_large",
|
||||
"DINOv2_large_reg",
|
||||
"DINOv2_small",
|
||||
"DINOv2_small_reg",
|
||||
"ViT",
|
||||
"ViT_base",
|
||||
"ViT_large",
|
||||
"ViT_small",
|
||||
"ViT_tiny",
|
||||
]
|
145
src/refiners/foundationals/dinov2/dinov2.py
Normal file
145
src/refiners/foundationals/dinov2/dinov2.py
Normal file
|
@ -0,0 +1,145 @@
|
|||
import torch
|
||||
|
||||
from refiners.foundationals.dinov2.vit import ViT
|
||||
|
||||
|
||||
class DINOv2_small(ViT):
|
||||
def __init__(
|
||||
self,
|
||||
device: torch.device | str | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
embedding_dim=384,
|
||||
patch_size=14,
|
||||
image_size=518,
|
||||
num_layers=12,
|
||||
num_heads=6,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
class DINOv2_base(ViT):
|
||||
def __init__(
|
||||
self,
|
||||
device: torch.device | str | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
embedding_dim=768,
|
||||
patch_size=14,
|
||||
image_size=518,
|
||||
num_layers=12,
|
||||
num_heads=12,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
class DINOv2_large(ViT):
|
||||
def __init__(
|
||||
self,
|
||||
device: torch.device | str | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
embedding_dim=1024,
|
||||
patch_size=14,
|
||||
image_size=518,
|
||||
num_layers=24,
|
||||
num_heads=16,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
# TODO: implement SwiGLU layer
|
||||
# class DINOv2_giant2(ViT):
|
||||
# def __init__(
|
||||
# self,
|
||||
# device: torch.device | str | None = None,
|
||||
# dtype: torch.dtype | None = None,
|
||||
# ) -> None:
|
||||
# super().__init__(
|
||||
# embedding_dim=1536,
|
||||
# patch_size=14,
|
||||
# image_size=518,
|
||||
# num_layers=40,
|
||||
# num_heads=24,
|
||||
# device=device,
|
||||
# dtype=dtype,
|
||||
# )
|
||||
|
||||
|
||||
class DINOv2_small_reg(ViT):
|
||||
def __init__(
|
||||
self,
|
||||
device: torch.device | str | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
embedding_dim=384,
|
||||
patch_size=14,
|
||||
image_size=518,
|
||||
num_layers=12,
|
||||
num_heads=6,
|
||||
num_registers=4,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
class DINOv2_base_reg(ViT):
|
||||
def __init__(
|
||||
self,
|
||||
device: torch.device | str | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
embedding_dim=768,
|
||||
patch_size=14,
|
||||
image_size=518,
|
||||
num_layers=12,
|
||||
num_heads=12,
|
||||
num_registers=4,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
class DINOv2_large_reg(ViT):
|
||||
def __init__(
|
||||
self,
|
||||
device: torch.device | str | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
embedding_dim=1024,
|
||||
patch_size=14,
|
||||
image_size=518,
|
||||
num_layers=24,
|
||||
num_heads=16,
|
||||
num_registers=4,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
# TODO: implement SwiGLU layer
|
||||
# class DINOv2_giant2_reg(ViT):
|
||||
# def __init__(
|
||||
# self,
|
||||
# device: torch.device | str | None = None,
|
||||
# dtype: torch.dtype | None = None,
|
||||
# ) -> None:
|
||||
# super().__init__(
|
||||
# embedding_dim=1536,
|
||||
# patch_size=14,
|
||||
# image_size=518,
|
||||
# num_layers=40,
|
||||
# num_heads=24,
|
||||
# num_registers=4,
|
||||
# device=device,
|
||||
# dtype=dtype,
|
||||
# )
|
372
src/refiners/foundationals/dinov2/vit.py
Normal file
372
src/refiners/foundationals/dinov2/vit.py
Normal file
|
@ -0,0 +1,372 @@
|
|||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.layers.activations import Activation
|
||||
|
||||
|
||||
class ClassToken(fl.Chain):
|
||||
"""Learnable token representing the class of the input."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
device: torch.device | str | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
super().__init__(
|
||||
fl.Parameter(
|
||||
*(1, embedding_dim),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class PositionalEncoder(fl.Residual):
|
||||
"""Encode the position of each patch in the input."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sequence_length: int,
|
||||
embedding_dim: int,
|
||||
device: torch.device | str | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
self.num_patches = sequence_length
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
super().__init__(
|
||||
fl.Parameter(
|
||||
*(sequence_length, embedding_dim),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class LayerScale(fl.WeightedModule):
|
||||
"""Scale the input tensor by a learnable parameter."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
init_value: float = 1.0,
|
||||
dtype: torch.dtype | None = None,
|
||||
device: torch.device | str | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
self.register_parameter(
|
||||
name="weight",
|
||||
param=torch.nn.Parameter(
|
||||
torch.full(
|
||||
size=(embedding_dim,),
|
||||
fill_value=init_value,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return x * self.weight
|
||||
|
||||
|
||||
class FeedForward(fl.Chain):
|
||||
"""Apply two linear transformations interleaved by an activation function."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
feedforward_dim: int,
|
||||
activation: Activation = fl.GeLU, # type: ignore
|
||||
device: torch.device | str | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
self.embedding_dim = embedding_dim
|
||||
self.feedforward_dim = feedforward_dim
|
||||
|
||||
super().__init__(
|
||||
fl.Linear(
|
||||
in_features=embedding_dim,
|
||||
out_features=feedforward_dim,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
activation(),
|
||||
fl.Linear(
|
||||
in_features=feedforward_dim,
|
||||
out_features=embedding_dim,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class PatchEncoder(fl.Chain):
|
||||
"""Encode an image into a sequence of patches."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
patch_size: int,
|
||||
device: torch.device | str | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.patch_size = patch_size
|
||||
|
||||
super().__init__(
|
||||
fl.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
), # (N,3,H,W) -> (N,D,P,P)
|
||||
fl.Reshape(out_channels, -1), # (N,D,P,P) -> (N,D,P²)
|
||||
fl.Transpose(1, 2), # (N,D,P²) -> (N,P²,D)
|
||||
)
|
||||
|
||||
|
||||
class TransformerLayer(fl.Chain):
|
||||
"""Apply a multi-head self-attention mechanism to the input tensor."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
num_heads: int,
|
||||
norm_eps: float,
|
||||
mlp_ratio: int,
|
||||
device: torch.device | str | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_heads = num_heads
|
||||
self.norm_eps = norm_eps
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
super().__init__(
|
||||
fl.Residual(
|
||||
fl.LayerNorm(
|
||||
normalized_shape=embedding_dim,
|
||||
eps=norm_eps,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
fl.SelfAttention(
|
||||
embedding_dim=embedding_dim,
|
||||
num_heads=num_heads,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
LayerScale(
|
||||
embedding_dim=embedding_dim,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
),
|
||||
fl.Residual(
|
||||
fl.LayerNorm(
|
||||
normalized_shape=embedding_dim,
|
||||
eps=norm_eps,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
FeedForward(
|
||||
embedding_dim=embedding_dim,
|
||||
feedforward_dim=embedding_dim * mlp_ratio,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
LayerScale(
|
||||
embedding_dim=embedding_dim,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Transformer(fl.Chain):
|
||||
"""Alias for a Chain of TransformerLayer."""
|
||||
|
||||
|
||||
class Registers(fl.Concatenate):
|
||||
"""Insert register tokens between CLS token and patches."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_registers: int,
|
||||
embedding_dim: int,
|
||||
device: torch.device | str | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
self.num_registers = num_registers
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
super().__init__(
|
||||
fl.Slicing(dim=1, end=1),
|
||||
fl.Parameter(
|
||||
*(num_registers, embedding_dim),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
fl.Slicing(dim=1, start=1),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
|
||||
class ViT(fl.Chain):
|
||||
"""Vision Transformer (ViT).
|
||||
|
||||
see https://arxiv.org/abs/2010.11929v2
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 768,
|
||||
patch_size: int = 16,
|
||||
image_size: int = 224,
|
||||
num_layers: int = 12,
|
||||
num_heads: int = 12,
|
||||
norm_eps: float = 1e-6,
|
||||
mlp_ratio: int = 4,
|
||||
num_registers: int = 0,
|
||||
device: torch.device | str | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
num_patches = image_size // patch_size
|
||||
self.embedding_dim = embedding_dim
|
||||
self.patch_size = patch_size
|
||||
self.image_size = image_size
|
||||
self.num_layers = num_layers
|
||||
self.num_heads = num_heads
|
||||
self.norm_eps = norm_eps
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.num_registers = num_registers
|
||||
|
||||
super().__init__(
|
||||
fl.Concatenate(
|
||||
ClassToken(
|
||||
embedding_dim=embedding_dim,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
PatchEncoder(
|
||||
in_channels=3,
|
||||
out_channels=embedding_dim,
|
||||
patch_size=patch_size,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
dim=1,
|
||||
),
|
||||
PositionalEncoder(
|
||||
sequence_length=num_patches**2 + 1,
|
||||
embedding_dim=embedding_dim,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
Transformer(
|
||||
TransformerLayer(
|
||||
embedding_dim=embedding_dim,
|
||||
num_heads=num_heads,
|
||||
norm_eps=norm_eps,
|
||||
mlp_ratio=mlp_ratio,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
),
|
||||
fl.LayerNorm(
|
||||
normalized_shape=embedding_dim,
|
||||
eps=norm_eps,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
)
|
||||
|
||||
if self.num_registers > 0:
|
||||
registers = Registers(
|
||||
num_registers=num_registers,
|
||||
embedding_dim=embedding_dim,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.insert_before_type(Transformer, registers)
|
||||
|
||||
|
||||
class ViT_tiny(ViT):
|
||||
def __init__(
|
||||
self,
|
||||
device: torch.device | str | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
embedding_dim=192,
|
||||
patch_size=16,
|
||||
image_size=224,
|
||||
num_layers=12,
|
||||
num_heads=3,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
class ViT_small(ViT):
|
||||
def __init__(
|
||||
self,
|
||||
device: torch.device | str | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
embedding_dim=384,
|
||||
patch_size=16,
|
||||
image_size=224,
|
||||
num_layers=12,
|
||||
num_heads=6,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
class ViT_base(ViT):
|
||||
def __init__(
|
||||
self,
|
||||
device: torch.device | str | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
embedding_dim=768,
|
||||
patch_size=16,
|
||||
image_size=224,
|
||||
num_layers=12,
|
||||
num_heads=12,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
class ViT_large(ViT):
|
||||
def __init__(
|
||||
self,
|
||||
device: torch.device | str | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
embedding_dim=1024,
|
||||
patch_size=16,
|
||||
image_size=224,
|
||||
num_layers=24,
|
||||
num_heads=16,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
Loading…
Reference in a new issue