From 4f94dfb49419c9bbf2d61f9ad3cb3ae742052c1c Mon Sep 17 00:00:00 2001 From: Laurent Date: Fri, 29 Mar 2024 17:42:08 +0000 Subject: [PATCH] implement dinov2 positional embedding interpolation --- pyproject.toml | 2 +- scripts/conversion/convert_dinov2.py | 2 +- scripts/prepare_test_weights.py | 12 +-- src/refiners/foundationals/dinov2/dinov2.py | 7 ++ src/refiners/foundationals/dinov2/vit.py | 101 ++++++++++++++++++-- 5 files changed, 107 insertions(+), 17 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 901a6b6..47ad7de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -144,7 +144,7 @@ exclude_also = [ [tool.typos.default] extend-words = { adaptee = "adaptee" } -extend-ignore-identifiers-re = ["NDArray*", "interm"] +extend-ignore-identifiers-re = ["NDArray*", "interm", "af000ded"] [tool.pytest.ini_options] filterwarnings = [ diff --git a/scripts/conversion/convert_dinov2.py b/scripts/conversion/convert_dinov2.py index 3d51ce1..9e907b4 100644 --- a/scripts/conversion/convert_dinov2.py +++ b/scripts/conversion/convert_dinov2.py @@ -20,7 +20,7 @@ def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None: rename_keys: list[tuple[str, str]] = [ ("cls_token", "Concatenate.ClassToken.Parameter.weight"), - ("pos_embed", "PositionalEncoder.Parameter.weight"), + ("pos_embed", "PositionalEncoder.PositionalEmbedding.Parameter.weight"), ("patch_embed.proj.weight", "Concatenate.PatchEncoder.Conv2d.weight"), ("patch_embed.proj.bias", "Concatenate.PatchEncoder.Conv2d.bias"), ("norm.weight", "LayerNorm.weight"), diff --git a/scripts/prepare_test_weights.py b/scripts/prepare_test_weights.py index 357d175..3dc96c2 100644 --- a/scripts/prepare_test_weights.py +++ b/scripts/prepare_test_weights.py @@ -688,37 +688,37 @@ def convert_dinov2(): "convert_dinov2.py", "tests/weights/dinov2_vits14_pretrain.pth", "tests/weights/dinov2_vits14_pretrain.safetensors", - expected_hash="b7f9b294", + expected_hash="af000ded", ) run_conversion_script( "convert_dinov2.py", "tests/weights/dinov2_vitb14_pretrain.pth", "tests/weights/dinov2_vitb14_pretrain.safetensors", - expected_hash="d72c767b", + expected_hash="d6294087", ) run_conversion_script( "convert_dinov2.py", "tests/weights/dinov2_vitl14_pretrain.pth", "tests/weights/dinov2_vitl14_pretrain.safetensors", - expected_hash="71eb98d1", + expected_hash="ddd4819f", ) run_conversion_script( "convert_dinov2.py", "tests/weights/dinov2_vits14_reg4_pretrain.pth", "tests/weights/dinov2_vits14_reg4_pretrain.safetensors", - expected_hash="89118b46", + expected_hash="080247c7", ) run_conversion_script( "convert_dinov2.py", "tests/weights/dinov2_vitb14_reg4_pretrain.pth", "tests/weights/dinov2_vitb14_reg4_pretrain.safetensors", - expected_hash="b0296f77", + expected_hash="5cd4d408", ) run_conversion_script( "convert_dinov2.py", "tests/weights/dinov2_vitl14_reg4_pretrain.pth", "tests/weights/dinov2_vitl14_reg4_pretrain.safetensors", - expected_hash="b3d877dc", + expected_hash="b1221702", ) diff --git a/src/refiners/foundationals/dinov2/dinov2.py b/src/refiners/foundationals/dinov2/dinov2.py index 7011cb0..13d598f 100644 --- a/src/refiners/foundationals/dinov2/dinov2.py +++ b/src/refiners/foundationals/dinov2/dinov2.py @@ -146,6 +146,7 @@ class DINOv2_small_reg(ViT): num_layers (int): 12 num_heads (int): 6 num_registers (int): 4 + interpolate_antialias (bool): True """ def __init__( @@ -166,6 +167,7 @@ class DINOv2_small_reg(ViT): num_layers=12, num_heads=6, num_registers=4, + interpolate_antialias=True, device=device, dtype=dtype, ) @@ -185,6 +187,7 @@ class DINOv2_base_reg(ViT): num_layers (int): 12 num_heads (int): 12 num_registers (int): 4 + interpolate_antialias (bool): True """ def __init__( @@ -205,6 +208,7 @@ class DINOv2_base_reg(ViT): num_layers=12, num_heads=12, num_registers=4, + interpolate_antialias=True, device=device, dtype=dtype, ) @@ -224,6 +228,7 @@ class DINOv2_large_reg(ViT): num_layers (int): 24 num_heads (int): 16 num_registers (int): 4 + interpolate_antialias (bool): True """ def __init__( @@ -244,6 +249,7 @@ class DINOv2_large_reg(ViT): num_layers=24, num_heads=16, num_registers=4, + interpolate_antialias=True, device=device, dtype=dtype, ) @@ -263,6 +269,7 @@ class DINOv2_large_reg(ViT): # num_layers=40, # num_heads=24, # num_registers=4, +# interpolate_antialias=True, # device=device, # dtype=dtype, # ) diff --git a/src/refiners/foundationals/dinov2/vit.py b/src/refiners/foundationals/dinov2/vit.py index 407faac..e2443ef 100644 --- a/src/refiners/foundationals/dinov2/vit.py +++ b/src/refiners/foundationals/dinov2/vit.py @@ -1,10 +1,13 @@ +from math import sqrt from typing import cast import torch from torch import Tensor import refiners.fluxion.layers as fl +from refiners.fluxion.context import Contexts from refiners.fluxion.layers.activations import Activation +from refiners.fluxion.utils import interpolate class ClassToken(fl.Chain): @@ -27,18 +30,20 @@ class ClassToken(fl.Chain): ) -class PositionalEncoder(fl.Residual): - """Encode the position of each patch in the input.""" +class PositionalEmbedding(fl.Chain): + """Learnable positional embedding.""" def __init__( self, sequence_length: int, embedding_dim: int, + patch_size: int, device: torch.device | str | None = None, dtype: torch.dtype | None = None, ) -> None: - self.num_patches = sequence_length + self.sequence_length = sequence_length self.embedding_dim = embedding_dim + self.patch_size = patch_size super().__init__( fl.Parameter( @@ -49,6 +54,55 @@ class PositionalEncoder(fl.Residual): ) +class InterpolateEmbedding(fl.Module): + """Interpolate the positional embeddings to match the input shape.""" + + def __init__( + self, + mode: str, + antialias: bool, + patch_size: int, + ) -> None: + super().__init__() + self.mode = mode + self.antialias = antialias + self.patch_size = patch_size + + def forward( + self, + x: torch.Tensor, + input: torch.Tensor, + ) -> torch.Tensor: + cls_embed = x[:, :1, :] # -> (1, 1, D) + patch_embed = x[:, 1:, :] # -> (1, N, D) + + N = patch_embed.shape[1] + D = patch_embed.shape[2] + M = int(sqrt(N)) + W = input.shape[2] + H = input.shape[3] + assert M * M == N, "The sequence length must be a square number." + + patch_embed = patch_embed.reshape(1, M, M, D) # -> (1, M, M, D) + patch_embed = patch_embed.permute(0, 3, 1, 2) # -> (1, D, M, M) + patch_embed = interpolate( + x=patch_embed.to(dtype=torch.float32), + mode=self.mode, + antialias=self.antialias, + size=torch.Size( + ( + W // self.patch_size, + H // self.patch_size, + ) + ), + ).to(dtype=cls_embed.dtype) # -> (1, D, w, h) + patch_embed = patch_embed.permute(0, 2, 3, 1) # -> (1, w, h, D) + patch_embed = patch_embed.reshape(1, -1, D) # -> (1, w*h, D) + + x = torch.cat((cls_embed, patch_embed), dim=1) # -> (1, w*h+1, D) + return x + + class LayerScale(fl.WeightedModule): """Scale the input tensor by a learnable parameter.""" @@ -125,6 +179,7 @@ class PatchEncoder(fl.Chain): self.patch_size = patch_size super().__init__( + fl.SetContext(context="dinov2_vit", key="input"), # save the original input fl.Conv2d( in_channels=in_channels, out_channels=out_channels, @@ -201,6 +256,10 @@ class Transformer(fl.Chain): """Alias for a Chain of TransformerLayer.""" +class PositionalEncoder(fl.Residual): + """Alias for a Residual.""" + + class Registers(fl.Concatenate): """Insert register tokens between CLS token and patches.""" @@ -243,6 +302,8 @@ class ViT(fl.Chain): norm_eps: float = 1e-6, mlp_ratio: int = 4, num_registers: int = 0, + interpolate_antialias: bool = False, + interpolate_mode: str = "bicubic", device: torch.device | str | None = None, dtype: torch.dtype | None = None, ) -> None: @@ -257,6 +318,8 @@ class ViT(fl.Chain): norm_eps: The epsilon value for normalization. mlp_ratio: The ratio for the multi-layer perceptron (MLP). num_registers: The number of registers. + interpolate_antialias: Whether to use antialiasing for interpolation. + interpolate_mode: The interpolation mode. device: The PyTorch device to use. dtype: The PyTorch data type to use. """ @@ -286,19 +349,32 @@ class ViT(fl.Chain): ), dim=1, ), - # TODO: support https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/models/vision_transformer.py#L179 PositionalEncoder( - sequence_length=num_patches**2 + 1, - embedding_dim=embedding_dim, - device=device, - dtype=dtype, + PositionalEmbedding( + sequence_length=num_patches**2 + 1, + embedding_dim=embedding_dim, + patch_size=patch_size, + device=device, + dtype=dtype, + ), + fl.Chain( + fl.Parallel( + fl.Identity(), + fl.UseContext(context="dinov2_vit", key="input"), + ), + InterpolateEmbedding( + mode=interpolate_mode, + antialias=interpolate_antialias, + patch_size=patch_size, + ), + ), ), Transformer( TransformerLayer( embedding_dim=embedding_dim, num_heads=num_heads, - norm_eps=norm_eps, mlp_ratio=mlp_ratio, + norm_eps=norm_eps, device=device, dtype=dtype, ) @@ -320,3 +396,10 @@ class ViT(fl.Chain): dtype=dtype, ) self.insert_before_type(Transformer, registers) + + def init_context(self) -> Contexts: + return { + "dinov2_vit": { + "input": None, + }, + }