feature: add DINOv2

Co-authored-by: Benjamin Trom <benjamintrom@gmail.com>
Laureηt 2023-12-14 17:27:32 +01:00 committed by GitHub
parent e2f2e33add
commit 9337d65e0e
4 changed files with 681 additions and 0 deletions


@@ -0,0 +1,135 @@
import argparse
import torch
def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:
"""Convert a DINOv2 weights from facebook to refiners."""
# get depth from "blocks" keys
depth = max([int(k.split(".")[1]) for k in weights.keys() if k.startswith("blocks.")]) + 1
# the mask token is only needed during pre-training, so drop it
del weights["mask_token"]
# squeeze the leading batch dimension of cls_token and pos_embed
weights["cls_token"] = weights["cls_token"].squeeze(0)
weights["pos_embed"] = weights["pos_embed"].squeeze(0)
rename_keys: list[tuple[str, str]] = [
("cls_token", "Concatenate.ClassToken.Parameter.weight"),
("pos_embed", "PositionalEncoder.Parameter.weight"),
("patch_embed.proj.weight", "Concatenate.PatchEncoder.Conv2d.weight"),
("patch_embed.proj.bias", "Concatenate.PatchEncoder.Conv2d.bias"),
("norm.weight", "LayerNorm.weight"),
("norm.bias", "LayerNorm.bias"),
]
for i in range(depth):
rename_keys.append(
(
f"blocks.{i}.norm1.weight",
f"Transformer.TransformerLayer_{i+1}.Residual_1.LayerNorm.weight",
),
)
rename_keys.append(
(
f"blocks.{i}.norm1.bias",
f"Transformer.TransformerLayer_{i+1}.Residual_1.LayerNorm.bias",
),
)
rename_keys.append(
(
f"blocks.{i}.attn.proj.weight",
f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Linear.weight",
),
)
rename_keys.append(
(
f"blocks.{i}.attn.proj.bias",
f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Linear.bias",
),
)
rename_keys.append(
(
f"blocks.{i}.ls1.gamma",
f"Transformer.TransformerLayer_{i+1}.Residual_1.LayerScale.weight",
),
)
rename_keys.append(
(
f"blocks.{i}.norm2.weight",
f"Transformer.TransformerLayer_{i+1}.Residual_2.LayerNorm.weight",
),
)
rename_keys.append(
(
f"blocks.{i}.norm2.bias",
f"Transformer.TransformerLayer_{i+1}.Residual_2.LayerNorm.bias",
),
)
rename_keys.append(
(
f"blocks.{i}.mlp.fc1.weight",
f"Transformer.TransformerLayer_{i+1}.Residual_2.FeedForward.Linear_1.weight",
),
)
rename_keys.append(
(
f"blocks.{i}.mlp.fc1.bias",
f"Transformer.TransformerLayer_{i+1}.Residual_2.FeedForward.Linear_1.bias",
),
)
rename_keys.append(
(
f"blocks.{i}.mlp.fc2.weight",
f"Transformer.TransformerLayer_{i+1}.Residual_2.FeedForward.Linear_2.weight",
),
)
rename_keys.append(
(
f"blocks.{i}.mlp.fc2.bias",
f"Transformer.TransformerLayer_{i+1}.Residual_2.FeedForward.Linear_2.bias",
),
)
rename_keys.append(
(
f"blocks.{i}.ls2.gamma",
f"Transformer.TransformerLayer_{i+1}.Residual_2.LayerScale.weight",
),
)
if "register_tokens" in weights:
weights["register_tokens"] = weights["register_tokens"].squeeze(0)
rename_keys.append(("register_tokens", "Registers.Parameter.weight"))
# rename keys
for old_key, new_key in rename_keys:
weights[new_key] = weights.pop(old_key)
# split the qkv weights and biases
for i in range(depth):
qkv_weight = weights.pop(f"blocks.{i}.attn.qkv.weight")
q_weight, k_weight, v_weight = qkv_weight.chunk(3, dim=0)
weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_1.weight"] = q_weight
weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_2.weight"] = k_weight
weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_3.weight"] = v_weight
qkv_bias = weights.pop(f"blocks.{i}.attn.qkv.bias")
q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_1.bias"] = q_bias
weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_2.bias"] = k_bias
weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_3.bias"] = v_bias
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--weights_path", type=str, required=True)
parser.add_argument("--output_path", type=str, required=True)
args = parser.parse_args()
weights = torch.load(args.weights_path)
convert_dinov2_facebook(weights)
torch.save(weights, args.output_path)
if __name__ == "__main__":
main()
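For illustration, here is a self-contained sketch that runs the converter above on a synthetic single-block checkpoint. The key names follow the facebookresearch/dinov2 layout used by the converter; the embedding size, patch size, and tensor shapes are deliberately tiny and purely hypothetical.

# Minimal sketch: run convert_dinov2_facebook on a synthetic depth-1 checkpoint.
# Shapes are illustrative only; real DINOv2 checkpoints are much larger.
import torch

d = 8  # hypothetical embedding dimension
fake = {
    "cls_token": torch.zeros(1, 1, d),
    "pos_embed": torch.zeros(1, 5, d),
    "mask_token": torch.zeros(1, d),
    "patch_embed.proj.weight": torch.zeros(d, 3, 2, 2),
    "patch_embed.proj.bias": torch.zeros(d),
    "norm.weight": torch.zeros(d),
    "norm.bias": torch.zeros(d),
    "blocks.0.norm1.weight": torch.zeros(d),
    "blocks.0.norm1.bias": torch.zeros(d),
    "blocks.0.attn.qkv.weight": torch.zeros(3 * d, d),
    "blocks.0.attn.qkv.bias": torch.zeros(3 * d),
    "blocks.0.attn.proj.weight": torch.zeros(d, d),
    "blocks.0.attn.proj.bias": torch.zeros(d),
    "blocks.0.ls1.gamma": torch.zeros(d),
    "blocks.0.norm2.weight": torch.zeros(d),
    "blocks.0.norm2.bias": torch.zeros(d),
    "blocks.0.mlp.fc1.weight": torch.zeros(4 * d, d),
    "blocks.0.mlp.fc1.bias": torch.zeros(4 * d),
    "blocks.0.mlp.fc2.weight": torch.zeros(d, 4 * d),
    "blocks.0.mlp.fc2.bias": torch.zeros(d),
    "blocks.0.ls2.gamma": torch.zeros(d),
}
convert_dinov2_facebook(fake)  # mutates the dict in place
# The dict now uses refiners naming, e.g.:
# "Concatenate.ClassToken.Parameter.weight", "PositionalEncoder.Parameter.weight",
# "Transformer.TransformerLayer_1.Residual_1.SelfAttention.Distribute.Linear_1.weight", ...
print(sorted(fake.keys()))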


@@ -0,0 +1,29 @@
from .dinov2 import (
DINOv2_base,
DINOv2_base_reg,
DINOv2_large,
DINOv2_large_reg,
DINOv2_small,
DINOv2_small_reg,
)
from .vit import (
ViT,
ViT_base,
ViT_large,
ViT_small,
ViT_tiny,
)
__all__ = [
"DINOv2_base",
"DINOv2_base_reg",
"DINOv2_large",
"DINOv2_large_reg",
"DINOv2_small",
"DINOv2_small_reg",
"ViT",
"ViT_base",
"ViT_large",
"ViT_small",
"ViT_tiny",
]
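For context, a minimal usage sketch of the exported models. The import path `refiners.foundationals.dinov2` is taken from the module imports in the next file, and the printed shape assumes the DINOv2 defaults defined there (image_size=518, patch_size=14, embedding_dim=384 for the small variant).

# Minimal usage sketch for the public API re-exported by this __init__.
import torch
from refiners.foundationals.dinov2 import DINOv2_small

model = DINOv2_small()
image = torch.randn(1, 3, 518, 518)
tokens = model(image)
# (518 // 14)² = 1369 patch tokens + 1 CLS token = 1370 tokens of dimension 384
print(tokens.shape)  # expected: torch.Size([1, 1370, 384])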


@@ -0,0 +1,145 @@
import torch
from refiners.foundationals.dinov2.vit import ViT
class DINOv2_small(ViT):
def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
embedding_dim=384,
patch_size=14,
image_size=518,
num_layers=12,
num_heads=6,
device=device,
dtype=dtype,
)
class DINOv2_base(ViT):
def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
embedding_dim=768,
patch_size=14,
image_size=518,
num_layers=12,
num_heads=12,
device=device,
dtype=dtype,
)
class DINOv2_large(ViT):
def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
embedding_dim=1024,
patch_size=14,
image_size=518,
num_layers=24,
num_heads=16,
device=device,
dtype=dtype,
)
# TODO: implement SwiGLU layer
# class DINOv2_giant2(ViT):
# def __init__(
# self,
# device: torch.device | str | None = None,
# dtype: torch.dtype | None = None,
# ) -> None:
# super().__init__(
# embedding_dim=1536,
# patch_size=14,
# image_size=518,
# num_layers=40,
# num_heads=24,
# device=device,
# dtype=dtype,
# )
class DINOv2_small_reg(ViT):
def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
embedding_dim=384,
patch_size=14,
image_size=518,
num_layers=12,
num_heads=6,
num_registers=4,
device=device,
dtype=dtype,
)
class DINOv2_base_reg(ViT):
def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
embedding_dim=768,
patch_size=14,
image_size=518,
num_layers=12,
num_heads=12,
num_registers=4,
device=device,
dtype=dtype,
)
class DINOv2_large_reg(ViT):
def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
embedding_dim=1024,
patch_size=14,
image_size=518,
num_layers=24,
num_heads=16,
num_registers=4,
device=device,
dtype=dtype,
)
# TODO: implement SwiGLU layer
# class DINOv2_giant2_reg(ViT):
# def __init__(
# self,
# device: torch.device | str | None = None,
# dtype: torch.dtype | None = None,
# ) -> None:
# super().__init__(
# embedding_dim=1536,
# patch_size=14,
# image_size=518,
# num_layers=40,
# num_heads=24,
# num_registers=4,
# device=device,
# dtype=dtype,
# )
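The `_reg` variants differ from their base counterparts only by `num_registers=4`, i.e. four extra learnable register tokens (in the spirit of "Vision Transformers Need Registers") inserted right after the CLS token. A short sketch of the effect on the output sequence length, worked out from the configurations above:

# Sketch: register variants add num_registers extra tokens to the output sequence.
import torch
from refiners.foundationals.dinov2 import DINOv2_small, DINOv2_small_reg

x = torch.randn(1, 3, 518, 518)
print(DINOv2_small()(x).shape)      # expected: torch.Size([1, 1370, 384])  -> 1369 patches + CLS
print(DINOv2_small_reg()(x).shape)  # expected: torch.Size([1, 1374, 384])  -> + 4 register tokens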


@@ -0,0 +1,372 @@
import torch
from torch import Tensor
import refiners.fluxion.layers as fl
from refiners.fluxion.layers.activations import Activation
class ClassToken(fl.Chain):
"""Learnable token representing the class of the input."""
def __init__(
self,
embedding_dim: int,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
self.embedding_dim = embedding_dim
super().__init__(
fl.Parameter(
*(1, embedding_dim),
device=device,
dtype=dtype,
),
)
class PositionalEncoder(fl.Residual):
"""Encode the position of each patch in the input."""
def __init__(
self,
sequence_length: int,
embedding_dim: int,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
self.num_patches = sequence_length
self.embedding_dim = embedding_dim
super().__init__(
fl.Parameter(
*(sequence_length, embedding_dim),
device=device,
dtype=dtype,
),
)
class LayerScale(fl.WeightedModule):
"""Scale the input tensor by a learnable parameter."""
def __init__(
self,
embedding_dim: int,
init_value: float = 1.0,
dtype: torch.dtype | None = None,
device: torch.device | str | None = None,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.register_parameter(
name="weight",
param=torch.nn.Parameter(
torch.full(
size=(embedding_dim,),
fill_value=init_value,
dtype=dtype,
device=device,
),
),
)
def forward(self, x: Tensor) -> Tensor:
return x * self.weight
class FeedForward(fl.Chain):
"""Apply two linear transformations interleaved by an activation function."""
def __init__(
self,
embedding_dim: int,
feedforward_dim: int,
activation: Activation = fl.GeLU, # type: ignore
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.feedforward_dim = feedforward_dim
super().__init__(
fl.Linear(
in_features=embedding_dim,
out_features=feedforward_dim,
device=device,
dtype=dtype,
),
activation(),
fl.Linear(
in_features=feedforward_dim,
out_features=embedding_dim,
device=device,
dtype=dtype,
),
)
class PatchEncoder(fl.Chain):
"""Encode an image into a sequence of patches."""
def __init__(
self,
in_channels: int,
out_channels: int,
patch_size: int,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
self.in_channels = in_channels
self.out_channels = out_channels
self.patch_size = patch_size
super().__init__(
fl.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=patch_size,
stride=patch_size,
device=device,
dtype=dtype,
), # (N,3,H,W) -> (N,D,P,P)
fl.Reshape(out_channels, -1), # (N,D,P,P) -> (N,D,P²)
fl.Transpose(1, 2), # (N,D,P²) -> (N,P²,D)
)
class TransformerLayer(fl.Chain):
"""Apply a multi-head self-attention mechanism to the input tensor."""
def __init__(
self,
embedding_dim: int,
num_heads: int,
norm_eps: float,
mlp_ratio: int,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.norm_eps = norm_eps
self.mlp_ratio = mlp_ratio
super().__init__(
fl.Residual(
fl.LayerNorm(
normalized_shape=embedding_dim,
eps=norm_eps,
device=device,
dtype=dtype,
),
fl.SelfAttention(
embedding_dim=embedding_dim,
num_heads=num_heads,
device=device,
dtype=dtype,
),
LayerScale(
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
),
),
fl.Residual(
fl.LayerNorm(
normalized_shape=embedding_dim,
eps=norm_eps,
device=device,
dtype=dtype,
),
FeedForward(
embedding_dim=embedding_dim,
feedforward_dim=embedding_dim * mlp_ratio,
device=device,
dtype=dtype,
),
LayerScale(
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
),
),
)
class Transformer(fl.Chain):
"""Alias for a Chain of TransformerLayer."""
class Registers(fl.Concatenate):
"""Insert register tokens between CLS token and patches."""
def __init__(
self,
num_registers: int,
embedding_dim: int,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
self.num_registers = num_registers
self.embedding_dim = embedding_dim
super().__init__(
fl.Slicing(dim=1, end=1),
fl.Parameter(
*(num_registers, embedding_dim),
device=device,
dtype=dtype,
),
fl.Slicing(dim=1, start=1),
dim=1,
)
class ViT(fl.Chain):
"""Vision Transformer (ViT).
see https://arxiv.org/abs/2010.11929v2
"""
def __init__(
self,
embedding_dim: int = 768,
patch_size: int = 16,
image_size: int = 224,
num_layers: int = 12,
num_heads: int = 12,
norm_eps: float = 1e-6,
mlp_ratio: int = 4,
num_registers: int = 0,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
num_patches = image_size // patch_size
self.embedding_dim = embedding_dim
self.patch_size = patch_size
self.image_size = image_size
self.num_layers = num_layers
self.num_heads = num_heads
self.norm_eps = norm_eps
self.mlp_ratio = mlp_ratio
self.num_registers = num_registers
super().__init__(
fl.Concatenate(
ClassToken(
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
),
PatchEncoder(
in_channels=3,
out_channels=embedding_dim,
patch_size=patch_size,
device=device,
dtype=dtype,
),
dim=1,
),
PositionalEncoder(
sequence_length=num_patches**2 + 1,
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
),
Transformer(
TransformerLayer(
embedding_dim=embedding_dim,
num_heads=num_heads,
norm_eps=norm_eps,
mlp_ratio=mlp_ratio,
device=device,
dtype=dtype,
)
for _ in range(num_layers)
),
fl.LayerNorm(
normalized_shape=embedding_dim,
eps=norm_eps,
device=device,
dtype=dtype,
),
)
if self.num_registers > 0:
registers = Registers(
num_registers=num_registers,
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
)
self.insert_before_type(Transformer, registers)
class ViT_tiny(ViT):
def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
embedding_dim=192,
patch_size=16,
image_size=224,
num_layers=12,
num_heads=3,
device=device,
dtype=dtype,
)
class ViT_small(ViT):
def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
embedding_dim=384,
patch_size=16,
image_size=224,
num_layers=12,
num_heads=6,
device=device,
dtype=dtype,
)
class ViT_base(ViT):
def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
embedding_dim=768,
patch_size=16,
image_size=224,
num_layers=12,
num_heads=12,
device=device,
dtype=dtype,
)
class ViT_large(ViT):
def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
embedding_dim=1024,
patch_size=16,
image_size=224,
num_layers=24,
num_heads=16,
device=device,
dtype=dtype,
)
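To make the token bookkeeping concrete, here is a short sketch of the shapes flowing through a default `ViT` (whose defaults match `ViT_base`): the PatchEncoder turns a 224×224 image into 14×14 = 196 patch tokens, the Concatenate prepends the CLS token, and the PositionalEncoder therefore uses a sequence length of 197.

# Sketch of the ViT_base token pipeline (image_size=224, patch_size=16, embedding_dim=768).
import torch
from refiners.foundationals.dinov2.vit import ViT

image_size, patch_size, dim = 224, 16, 768
patches_per_side = image_size // patch_size      # 14
num_patch_tokens = patches_per_side ** 2         # 196, produced by PatchEncoder
sequence_length = num_patch_tokens + 1           # 197, after the CLS token is concatenated

vit = ViT()  # defaults match ViT_base
y = vit(torch.randn(1, 3, image_size, image_size))
print(y.shape)  # expected: torch.Size([1, 197, 768])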