From 9337d65e0e78d49196b0ce0cd81276652bd1c125 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Laure=CE=B7t?=
Date: Thu, 14 Dec 2023 17:27:32 +0100
Subject: [PATCH] feature: add DINOv2

Co-authored-by: Benjamin Trom
---
 scripts/conversion/convert_dinov2.py          | 135 +++++++
 src/refiners/foundationals/dinov2/__init__.py |  29 ++
 src/refiners/foundationals/dinov2/dinov2.py   | 145 +++++++
 src/refiners/foundationals/dinov2/vit.py      | 372 ++++++++++++++++++
 4 files changed, 681 insertions(+)
 create mode 100644 scripts/conversion/convert_dinov2.py
 create mode 100644 src/refiners/foundationals/dinov2/__init__.py
 create mode 100644 src/refiners/foundationals/dinov2/dinov2.py
 create mode 100644 src/refiners/foundationals/dinov2/vit.py

diff --git a/scripts/conversion/convert_dinov2.py b/scripts/conversion/convert_dinov2.py
new file mode 100644
index 0000000..98dafb5
--- /dev/null
+++ b/scripts/conversion/convert_dinov2.py
@@ -0,0 +1,135 @@
+import argparse
+
+import torch
+
+
+def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:
+    """Convert DINOv2 weights from facebook to refiners."""
+    # get depth from "blocks" keys
+    depth = max([int(k.split(".")[1]) for k in weights.keys() if k.startswith("blocks.")]) + 1
+
+    # only needed when pre-training
+    del weights["mask_token"]
+
+    # squeeze cls_token and position_embeddings
+    weights["cls_token"] = weights["cls_token"].squeeze(0)
+    weights["pos_embed"] = weights["pos_embed"].squeeze(0)
+
+    rename_keys: list[tuple[str, str]] = [
+        ("cls_token", "Concatenate.ClassToken.Parameter.weight"),
+        ("pos_embed", "PositionalEncoder.Parameter.weight"),
+        ("patch_embed.proj.weight", "Concatenate.PatchEncoder.Conv2d.weight"),
+        ("patch_embed.proj.bias", "Concatenate.PatchEncoder.Conv2d.bias"),
+        ("norm.weight", "LayerNorm.weight"),
+        ("norm.bias", "LayerNorm.bias"),
+    ]
+    for i in range(depth):
+        rename_keys.append(
+            (
+                f"blocks.{i}.norm1.weight",
+                f"Transformer.TransformerLayer_{i+1}.Residual_1.LayerNorm.weight",
+            ),
+        )
+        rename_keys.append(
+            (
+                f"blocks.{i}.norm1.bias",
+                f"Transformer.TransformerLayer_{i+1}.Residual_1.LayerNorm.bias",
+            ),
+        )
+        rename_keys.append(
+            (
+                f"blocks.{i}.attn.proj.weight",
+                f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Linear.weight",
+            ),
+        )
+        rename_keys.append(
+            (
+                f"blocks.{i}.attn.proj.bias",
+                f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Linear.bias",
+            ),
+        )
+        rename_keys.append(
+            (
+                f"blocks.{i}.ls1.gamma",
+                f"Transformer.TransformerLayer_{i+1}.Residual_1.LayerScale.weight",
+            ),
+        )
+        rename_keys.append(
+            (
+                f"blocks.{i}.norm2.weight",
+                f"Transformer.TransformerLayer_{i+1}.Residual_2.LayerNorm.weight",
+            ),
+        )
+        rename_keys.append(
+            (
+                f"blocks.{i}.norm2.bias",
+                f"Transformer.TransformerLayer_{i+1}.Residual_2.LayerNorm.bias",
+            ),
+        )
+        rename_keys.append(
+            (
+                f"blocks.{i}.mlp.fc1.weight",
+                f"Transformer.TransformerLayer_{i+1}.Residual_2.FeedForward.Linear_1.weight",
+            ),
+        )
+        rename_keys.append(
+            (
+                f"blocks.{i}.mlp.fc1.bias",
+                f"Transformer.TransformerLayer_{i+1}.Residual_2.FeedForward.Linear_1.bias",
+            ),
+        )
+        rename_keys.append(
+            (
+                f"blocks.{i}.mlp.fc2.weight",
+                f"Transformer.TransformerLayer_{i+1}.Residual_2.FeedForward.Linear_2.weight",
+            ),
+        )
+        rename_keys.append(
+            (
+                f"blocks.{i}.mlp.fc2.bias",
+                f"Transformer.TransformerLayer_{i+1}.Residual_2.FeedForward.Linear_2.bias",
+            ),
+        )
+        rename_keys.append(
+            (
+                f"blocks.{i}.ls2.gamma",
+                f"Transformer.TransformerLayer_{i+1}.Residual_2.LayerScale.weight",
+            ),
+        )
+
+    if "register_tokens" in weights:
+        weights["register_tokens"] = weights["register_tokens"].squeeze(0)
+        rename_keys.append(("register_tokens", "Registers.Parameter.weight"))
+
+    # rename keys
+    for old_key, new_key in rename_keys:
+        weights[new_key] = weights.pop(old_key)
+
+    # split the qkv weights and biases
+    for i in range(depth):
+        qkv_weight = weights.pop(f"blocks.{i}.attn.qkv.weight")
+        q_weight, k_weight, v_weight = qkv_weight.chunk(3, dim=0)
+        weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_1.weight"] = q_weight
+        weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_2.weight"] = k_weight
+        weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_3.weight"] = v_weight
+
+        qkv_bias = weights.pop(f"blocks.{i}.attn.qkv.bias")
+        q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
+        weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_1.bias"] = q_bias
+        weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_2.bias"] = k_bias
+        weights[f"Transformer.TransformerLayer_{i+1}.Residual_1.SelfAttention.Distribute.Linear_3.bias"] = v_bias
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--weights_path", type=str, required=True)
+    parser.add_argument("--output_path", type=str, required=True)
+    args = parser.parse_args()
+
+    weights = torch.load(args.weights_path)
+    convert_dinov2_facebook(weights)
+    torch.save(weights, args.output_path)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/refiners/foundationals/dinov2/__init__.py b/src/refiners/foundationals/dinov2/__init__.py
new file mode 100644
index 0000000..8e802bb
--- /dev/null
+++ b/src/refiners/foundationals/dinov2/__init__.py
@@ -0,0 +1,29 @@
+from .dinov2 import (
+    DINOv2_base,
+    DINOv2_base_reg,
+    DINOv2_large,
+    DINOv2_large_reg,
+    DINOv2_small,
+    DINOv2_small_reg,
+)
+from .vit import (
+    ViT,
+    ViT_base,
+    ViT_large,
+    ViT_small,
+    ViT_tiny,
+)
+
+__all__ = [
+    "DINOv2_base",
+    "DINOv2_base_reg",
+    "DINOv2_large",
+    "DINOv2_large_reg",
+    "DINOv2_small",
+    "DINOv2_small_reg",
+    "ViT",
+    "ViT_base",
+    "ViT_large",
+    "ViT_small",
+    "ViT_tiny",
+]
diff --git a/src/refiners/foundationals/dinov2/dinov2.py b/src/refiners/foundationals/dinov2/dinov2.py
new file mode 100644
index 0000000..bd7a3a7
--- /dev/null
+++ b/src/refiners/foundationals/dinov2/dinov2.py
@@ -0,0 +1,145 @@
+import torch
+
+from refiners.foundationals.dinov2.vit import ViT
+
+
+class DINOv2_small(ViT):
+    def __init__(
+        self,
+        device: torch.device | str | None = None,
+        dtype: torch.dtype | None = None,
+    ) -> None:
+        super().__init__(
+            embedding_dim=384,
+            patch_size=14,
+            image_size=518,
+            num_layers=12,
+            num_heads=6,
+            device=device,
+            dtype=dtype,
+        )
+
+
+class DINOv2_base(ViT):
+    def __init__(
+        self,
+        device: torch.device | str | None = None,
+        dtype: torch.dtype | None = None,
+    ) -> None:
+        super().__init__(
+            embedding_dim=768,
+            patch_size=14,
+            image_size=518,
+            num_layers=12,
+            num_heads=12,
+            device=device,
+            dtype=dtype,
+        )
+
+
+class DINOv2_large(ViT):
+    def __init__(
+        self,
+        device: torch.device | str | None = None,
+        dtype: torch.dtype | None = None,
+    ) -> None:
+        super().__init__(
+            embedding_dim=1024,
+            patch_size=14,
+            image_size=518,
+            num_layers=24,
+            num_heads=16,
+            device=device,
+            dtype=dtype,
+        )
+
+
+# TODO: implement SwiGLU layer
+# class DINOv2_giant2(ViT):
+#     def __init__(
+#         self,
+#         device: torch.device | str | None = None,
+#         dtype: torch.dtype | None = None,
+#     ) -> None:
+#         super().__init__(
+#             embedding_dim=1536,
+#             patch_size=14,
+#             image_size=518,
+#             num_layers=40,
+#             num_heads=24,
+#             device=device,
+#             dtype=dtype,
+#         )
+
+
+class DINOv2_small_reg(ViT):
+    def __init__(
+        self,
+        device: torch.device | str | None = None,
+        dtype: torch.dtype | None = None,
+    ) -> None:
+        super().__init__(
+            embedding_dim=384,
+            patch_size=14,
+            image_size=518,
+            num_layers=12,
+            num_heads=6,
+            num_registers=4,
+            device=device,
+            dtype=dtype,
+        )
+
+
+class DINOv2_base_reg(ViT):
+    def __init__(
+        self,
+        device: torch.device | str | None = None,
+        dtype: torch.dtype | None = None,
+    ) -> None:
+        super().__init__(
+            embedding_dim=768,
+            patch_size=14,
+            image_size=518,
+            num_layers=12,
+            num_heads=12,
+            num_registers=4,
+            device=device,
+            dtype=dtype,
+        )
+
+
+class DINOv2_large_reg(ViT):
+    def __init__(
+        self,
+        device: torch.device | str | None = None,
+        dtype: torch.dtype | None = None,
+    ) -> None:
+        super().__init__(
+            embedding_dim=1024,
+            patch_size=14,
+            image_size=518,
+            num_layers=24,
+            num_heads=16,
+            num_registers=4,
+            device=device,
+            dtype=dtype,
+        )
+
+
+# TODO: implement SwiGLU layer
+# class DINOv2_giant2_reg(ViT):
+#     def __init__(
+#         self,
+#         device: torch.device | str | None = None,
+#         dtype: torch.dtype | None = None,
+#     ) -> None:
+#         super().__init__(
+#             embedding_dim=1536,
+#             patch_size=14,
+#             image_size=518,
+#             num_layers=40,
+#             num_heads=24,
+#             num_registers=4,
+#             device=device,
+#             dtype=dtype,
+#         )
diff --git a/src/refiners/foundationals/dinov2/vit.py b/src/refiners/foundationals/dinov2/vit.py
new file mode 100644
index 0000000..f80052c
--- /dev/null
+++ b/src/refiners/foundationals/dinov2/vit.py
@@ -0,0 +1,372 @@
+import torch
+from torch import Tensor
+
+import refiners.fluxion.layers as fl
+from refiners.fluxion.layers.activations import Activation
+
+
+class ClassToken(fl.Chain):
+    """Learnable token representing the class of the input."""
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        device: torch.device | str | None = None,
+        dtype: torch.dtype | None = None,
+    ) -> None:
+        self.embedding_dim = embedding_dim
+
+        super().__init__(
+            fl.Parameter(
+                *(1, embedding_dim),
+                device=device,
+                dtype=dtype,
+            ),
+        )
+
+
+class PositionalEncoder(fl.Residual):
+    """Encode the position of each patch in the input."""
+
+    def __init__(
+        self,
+        sequence_length: int,
+        embedding_dim: int,
+        device: torch.device | str | None = None,
+        dtype: torch.dtype | None = None,
+    ) -> None:
+        self.num_patches = sequence_length
+        self.embedding_dim = embedding_dim
+
+        super().__init__(
+            fl.Parameter(
+                *(sequence_length, embedding_dim),
+                device=device,
+                dtype=dtype,
+            ),
+        )
+
+
+class LayerScale(fl.WeightedModule):
+    """Scale the input tensor by a learnable parameter."""
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        init_value: float = 1.0,
+        dtype: torch.dtype | None = None,
+        device: torch.device | str | None = None,
+    ) -> None:
+        super().__init__()
+        self.embedding_dim = embedding_dim
+
+        self.register_parameter(
+            name="weight",
+            param=torch.nn.Parameter(
+                torch.full(
+                    size=(embedding_dim,),
+                    fill_value=init_value,
+                    dtype=dtype,
+                    device=device,
+                ),
+            ),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        return x * self.weight
+
+
+class FeedForward(fl.Chain):
+    """Apply two linear transformations interleaved by an activation function."""
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        feedforward_dim: int,
+        activation: Activation = fl.GeLU,  # type: ignore
+        device: torch.device | str | None = None,
+        dtype: torch.dtype | None = None,
+    ) -> None:
+        self.embedding_dim = embedding_dim
+        self.feedforward_dim = feedforward_dim
+
+        super().__init__(
+            fl.Linear(
+                in_features=embedding_dim,
+                out_features=feedforward_dim,
+                device=device,
+                dtype=dtype,
+            ),
+            activation(),
+            fl.Linear(
+                in_features=feedforward_dim,
+                out_features=embedding_dim,
+                device=device,
+                dtype=dtype,
+            ),
+        )
+
+
+class PatchEncoder(fl.Chain):
+    """Encode an image into a sequence of patches."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        patch_size: int,
+        device: torch.device | str | None = None,
+        dtype: torch.dtype | None = None,
+    ) -> None:
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.patch_size = patch_size
+
+        super().__init__(
+            fl.Conv2d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=patch_size,
+                stride=patch_size,
+                device=device,
+                dtype=dtype,
+            ),  # (N,3,H,W) -> (N,D,P,P)
+            fl.Reshape(out_channels, -1),  # (N,D,P,P) -> (N,D,P²)
+            fl.Transpose(1, 2),  # (N,D,P²) -> (N,P²,D)
+        )
+
+
+class TransformerLayer(fl.Chain):
+    """Apply a multi-head self-attention mechanism to the input tensor."""
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int,
+        norm_eps: float,
+        mlp_ratio: int,
+        device: torch.device | str | None = None,
+        dtype: torch.dtype | None = None,
+    ) -> None:
+        self.embedding_dim = embedding_dim
+        self.num_heads = num_heads
+        self.norm_eps = norm_eps
+        self.mlp_ratio = mlp_ratio
+
+        super().__init__(
+            fl.Residual(
+                fl.LayerNorm(
+                    normalized_shape=embedding_dim,
+                    eps=norm_eps,
+                    device=device,
+                    dtype=dtype,
+                ),
+                fl.SelfAttention(
+                    embedding_dim=embedding_dim,
+                    num_heads=num_heads,
+                    device=device,
+                    dtype=dtype,
+                ),
+                LayerScale(
+                    embedding_dim=embedding_dim,
+                    device=device,
+                    dtype=dtype,
+                ),
+            ),
+            fl.Residual(
+                fl.LayerNorm(
+                    normalized_shape=embedding_dim,
+                    eps=norm_eps,
+                    device=device,
+                    dtype=dtype,
+                ),
+                FeedForward(
+                    embedding_dim=embedding_dim,
+                    feedforward_dim=embedding_dim * mlp_ratio,
+                    device=device,
+                    dtype=dtype,
+                ),
+                LayerScale(
+                    embedding_dim=embedding_dim,
+                    device=device,
+                    dtype=dtype,
+                ),
+            ),
+        )
+
+
+class Transformer(fl.Chain):
+    """Alias for a Chain of TransformerLayer."""
+
+
+class Registers(fl.Concatenate):
+    """Insert register tokens between CLS token and patches."""
+
+    def __init__(
+        self,
+        num_registers: int,
+        embedding_dim: int,
+        device: torch.device | str | None = None,
+        dtype: torch.dtype | None = None,
+    ) -> None:
+        self.num_registers = num_registers
+        self.embedding_dim = embedding_dim
+
+        super().__init__(
+            fl.Slicing(dim=1, end=1),
+            fl.Parameter(
+                *(num_registers, embedding_dim),
+                device=device,
+                dtype=dtype,
+            ),
+            fl.Slicing(dim=1, start=1),
+            dim=1,
+        )
+
+
+class ViT(fl.Chain):
+    """Vision Transformer (ViT).
+
+    see https://arxiv.org/abs/2010.11929v2
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int = 768,
+        patch_size: int = 16,
+        image_size: int = 224,
+        num_layers: int = 12,
+        num_heads: int = 12,
+        norm_eps: float = 1e-6,
+        mlp_ratio: int = 4,
+        num_registers: int = 0,
+        device: torch.device | str | None = None,
+        dtype: torch.dtype | None = None,
+    ) -> None:
+        num_patches = image_size // patch_size
+        self.embedding_dim = embedding_dim
+        self.patch_size = patch_size
+        self.image_size = image_size
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.norm_eps = norm_eps
+        self.mlp_ratio = mlp_ratio
+        self.num_registers = num_registers
+
+        super().__init__(
+            fl.Concatenate(
+                ClassToken(
+                    embedding_dim=embedding_dim,
+                    device=device,
+                    dtype=dtype,
+                ),
+                PatchEncoder(
+                    in_channels=3,
+                    out_channels=embedding_dim,
+                    patch_size=patch_size,
+                    device=device,
+                    dtype=dtype,
+                ),
+                dim=1,
+            ),
+            PositionalEncoder(
+                sequence_length=num_patches**2 + 1,
+                embedding_dim=embedding_dim,
+                device=device,
+                dtype=dtype,
+            ),
+            Transformer(
+                TransformerLayer(
+                    embedding_dim=embedding_dim,
+                    num_heads=num_heads,
+                    norm_eps=norm_eps,
+                    mlp_ratio=mlp_ratio,
+                    device=device,
+                    dtype=dtype,
+                )
+                for _ in range(num_layers)
+            ),
+            fl.LayerNorm(
+                normalized_shape=embedding_dim,
+                eps=norm_eps,
+                device=device,
+                dtype=dtype,
+            ),
+        )
+
+        if self.num_registers > 0:
+            registers = Registers(
+                num_registers=num_registers,
+                embedding_dim=embedding_dim,
+                device=device,
+                dtype=dtype,
+            )
+            self.insert_before_type(Transformer, registers)
+
+
+class ViT_tiny(ViT):
+    def __init__(
+        self,
+        device: torch.device | str | None = None,
+        dtype: torch.dtype | None = None,
+    ) -> None:
+        super().__init__(
+            embedding_dim=192,
+            patch_size=16,
+            image_size=224,
+            num_layers=12,
+            num_heads=3,
+            device=device,
+            dtype=dtype,
+        )
+
+
+class ViT_small(ViT):
+    def __init__(
+        self,
+        device: torch.device | str | None = None,
+        dtype: torch.dtype | None = None,
+    ) -> None:
+        super().__init__(
+            embedding_dim=384,
+            patch_size=16,
+            image_size=224,
+            num_layers=12,
+            num_heads=6,
+            device=device,
+            dtype=dtype,
+        )
+
+
+class ViT_base(ViT):
+    def __init__(
+        self,
+        device: torch.device | str | None = None,
+        dtype: torch.dtype | None = None,
+    ) -> None:
+        super().__init__(
+            embedding_dim=768,
+            patch_size=16,
+            image_size=224,
+            num_layers=12,
+            num_heads=12,
+            device=device,
+            dtype=dtype,
+        )
+
+
+class ViT_large(ViT):
+    def __init__(
+        self,
+        device: torch.device | str | None = None,
+        dtype: torch.dtype | None = None,
+    ) -> None:
+        super().__init__(
+            embedding_dim=1024,
+            patch_size=16,
+            image_size=224,
+            num_layers=24,
+            num_heads=16,
+            device=device,
+            dtype=dtype,
+        )
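
Example usage (a minimal sketch, not part of the diff: the checkpoint and output filenames are placeholders, and loading the converted state dict with load_state_dict assumes the refiners chain exposes the standard torch.nn.Module interface with matching key names; prefer the library's own loading utilities if they differ):

    # convert an official facebookresearch DINOv2 checkpoint to the refiners naming scheme
    python scripts/conversion/convert_dinov2.py \
        --weights_path dinov2_vits14_pretrain.pth \
        --output_path dinov2_vits14_refiners.pth

    # load the converted weights and extract features (hypothetical filenames)
    import torch
    from refiners.foundationals.dinov2 import DINOv2_small

    model = DINOv2_small()
    model.load_state_dict(torch.load("dinov2_vits14_refiners.pth"))
    image = torch.randn(1, 3, 518, 518)  # the DINOv2 variants above use image_size=518 and patch_size=14
    features = model(image)  # expected: (1, 1 + 37**2, 384) = (1, 1370, 384) for the small variant (+4 tokens for the _reg variants)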