implement dinov2 positional embedding interpolation

This commit is contained in:
Laurent 2024-03-29 17:42:08 +00:00 committed by Laureηt
parent 0336bc78b5
commit 4f94dfb494
5 changed files with 107 additions and 17 deletions

View file

@ -144,7 +144,7 @@ exclude_also = [
[tool.typos.default] [tool.typos.default]
extend-words = { adaptee = "adaptee" } extend-words = { adaptee = "adaptee" }
extend-ignore-identifiers-re = ["NDArray*", "interm"] extend-ignore-identifiers-re = ["NDArray*", "interm", "af000ded"]
[tool.pytest.ini_options] [tool.pytest.ini_options]
filterwarnings = [ filterwarnings = [

View file

@ -20,7 +20,7 @@ def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:
rename_keys: list[tuple[str, str]] = [ rename_keys: list[tuple[str, str]] = [
("cls_token", "Concatenate.ClassToken.Parameter.weight"), ("cls_token", "Concatenate.ClassToken.Parameter.weight"),
("pos_embed", "PositionalEncoder.Parameter.weight"), ("pos_embed", "PositionalEncoder.PositionalEmbedding.Parameter.weight"),
("patch_embed.proj.weight", "Concatenate.PatchEncoder.Conv2d.weight"), ("patch_embed.proj.weight", "Concatenate.PatchEncoder.Conv2d.weight"),
("patch_embed.proj.bias", "Concatenate.PatchEncoder.Conv2d.bias"), ("patch_embed.proj.bias", "Concatenate.PatchEncoder.Conv2d.bias"),
("norm.weight", "LayerNorm.weight"), ("norm.weight", "LayerNorm.weight"),

View file

@ -688,37 +688,37 @@ def convert_dinov2():
"convert_dinov2.py", "convert_dinov2.py",
"tests/weights/dinov2_vits14_pretrain.pth", "tests/weights/dinov2_vits14_pretrain.pth",
"tests/weights/dinov2_vits14_pretrain.safetensors", "tests/weights/dinov2_vits14_pretrain.safetensors",
expected_hash="b7f9b294", expected_hash="af000ded",
) )
run_conversion_script( run_conversion_script(
"convert_dinov2.py", "convert_dinov2.py",
"tests/weights/dinov2_vitb14_pretrain.pth", "tests/weights/dinov2_vitb14_pretrain.pth",
"tests/weights/dinov2_vitb14_pretrain.safetensors", "tests/weights/dinov2_vitb14_pretrain.safetensors",
expected_hash="d72c767b", expected_hash="d6294087",
) )
run_conversion_script( run_conversion_script(
"convert_dinov2.py", "convert_dinov2.py",
"tests/weights/dinov2_vitl14_pretrain.pth", "tests/weights/dinov2_vitl14_pretrain.pth",
"tests/weights/dinov2_vitl14_pretrain.safetensors", "tests/weights/dinov2_vitl14_pretrain.safetensors",
expected_hash="71eb98d1", expected_hash="ddd4819f",
) )
run_conversion_script( run_conversion_script(
"convert_dinov2.py", "convert_dinov2.py",
"tests/weights/dinov2_vits14_reg4_pretrain.pth", "tests/weights/dinov2_vits14_reg4_pretrain.pth",
"tests/weights/dinov2_vits14_reg4_pretrain.safetensors", "tests/weights/dinov2_vits14_reg4_pretrain.safetensors",
expected_hash="89118b46", expected_hash="080247c7",
) )
run_conversion_script( run_conversion_script(
"convert_dinov2.py", "convert_dinov2.py",
"tests/weights/dinov2_vitb14_reg4_pretrain.pth", "tests/weights/dinov2_vitb14_reg4_pretrain.pth",
"tests/weights/dinov2_vitb14_reg4_pretrain.safetensors", "tests/weights/dinov2_vitb14_reg4_pretrain.safetensors",
expected_hash="b0296f77", expected_hash="5cd4d408",
) )
run_conversion_script( run_conversion_script(
"convert_dinov2.py", "convert_dinov2.py",
"tests/weights/dinov2_vitl14_reg4_pretrain.pth", "tests/weights/dinov2_vitl14_reg4_pretrain.pth",
"tests/weights/dinov2_vitl14_reg4_pretrain.safetensors", "tests/weights/dinov2_vitl14_reg4_pretrain.safetensors",
expected_hash="b3d877dc", expected_hash="b1221702",
) )

View file

@ -146,6 +146,7 @@ class DINOv2_small_reg(ViT):
num_layers (int): 12 num_layers (int): 12
num_heads (int): 6 num_heads (int): 6
num_registers (int): 4 num_registers (int): 4
interpolate_antialias (bool): True
""" """
def __init__( def __init__(
@ -166,6 +167,7 @@ class DINOv2_small_reg(ViT):
num_layers=12, num_layers=12,
num_heads=6, num_heads=6,
num_registers=4, num_registers=4,
interpolate_antialias=True,
device=device, device=device,
dtype=dtype, dtype=dtype,
) )
@ -185,6 +187,7 @@ class DINOv2_base_reg(ViT):
num_layers (int): 12 num_layers (int): 12
num_heads (int): 12 num_heads (int): 12
num_registers (int): 4 num_registers (int): 4
interpolate_antialias (bool): True
""" """
def __init__( def __init__(
@ -205,6 +208,7 @@ class DINOv2_base_reg(ViT):
num_layers=12, num_layers=12,
num_heads=12, num_heads=12,
num_registers=4, num_registers=4,
interpolate_antialias=True,
device=device, device=device,
dtype=dtype, dtype=dtype,
) )
@ -224,6 +228,7 @@ class DINOv2_large_reg(ViT):
num_layers (int): 24 num_layers (int): 24
num_heads (int): 16 num_heads (int): 16
num_registers (int): 4 num_registers (int): 4
interpolate_antialias (bool): True
""" """
def __init__( def __init__(
@ -244,6 +249,7 @@ class DINOv2_large_reg(ViT):
num_layers=24, num_layers=24,
num_heads=16, num_heads=16,
num_registers=4, num_registers=4,
interpolate_antialias=True,
device=device, device=device,
dtype=dtype, dtype=dtype,
) )
@ -263,6 +269,7 @@ class DINOv2_large_reg(ViT):
# num_layers=40, # num_layers=40,
# num_heads=24, # num_heads=24,
# num_registers=4, # num_registers=4,
# interpolate_antialias=True,
# device=device, # device=device,
# dtype=dtype, # dtype=dtype,
# ) # )

View file

@ -1,10 +1,13 @@
from math import sqrt
from typing import cast from typing import cast
import torch import torch
from torch import Tensor from torch import Tensor
import refiners.fluxion.layers as fl import refiners.fluxion.layers as fl
from refiners.fluxion.context import Contexts
from refiners.fluxion.layers.activations import Activation from refiners.fluxion.layers.activations import Activation
from refiners.fluxion.utils import interpolate
class ClassToken(fl.Chain): class ClassToken(fl.Chain):
@ -27,18 +30,20 @@ class ClassToken(fl.Chain):
) )
class PositionalEncoder(fl.Residual): class PositionalEmbedding(fl.Chain):
"""Encode the position of each patch in the input.""" """Learnable positional embedding."""
def __init__( def __init__(
self, self,
sequence_length: int, sequence_length: int,
embedding_dim: int, embedding_dim: int,
patch_size: int,
device: torch.device | str | None = None, device: torch.device | str | None = None,
dtype: torch.dtype | None = None, dtype: torch.dtype | None = None,
) -> None: ) -> None:
self.num_patches = sequence_length self.sequence_length = sequence_length
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
self.patch_size = patch_size
super().__init__( super().__init__(
fl.Parameter( fl.Parameter(
@ -49,6 +54,55 @@ class PositionalEncoder(fl.Residual):
) )
class InterpolateEmbedding(fl.Module):
"""Interpolate the positional embeddings to match the input shape."""
def __init__(
self,
mode: str,
antialias: bool,
patch_size: int,
) -> None:
super().__init__()
self.mode = mode
self.antialias = antialias
self.patch_size = patch_size
def forward(
self,
x: torch.Tensor,
input: torch.Tensor,
) -> torch.Tensor:
cls_embed = x[:, :1, :] # -> (1, 1, D)
patch_embed = x[:, 1:, :] # -> (1, N, D)
N = patch_embed.shape[1]
D = patch_embed.shape[2]
M = int(sqrt(N))
W = input.shape[2]
H = input.shape[3]
assert M * M == N, "The sequence length must be a square number."
patch_embed = patch_embed.reshape(1, M, M, D) # -> (1, M, M, D)
patch_embed = patch_embed.permute(0, 3, 1, 2) # -> (1, D, M, M)
patch_embed = interpolate(
x=patch_embed.to(dtype=torch.float32),
mode=self.mode,
antialias=self.antialias,
size=torch.Size(
(
W // self.patch_size,
H // self.patch_size,
)
),
).to(dtype=cls_embed.dtype) # -> (1, D, w, h)
patch_embed = patch_embed.permute(0, 2, 3, 1) # -> (1, w, h, D)
patch_embed = patch_embed.reshape(1, -1, D) # -> (1, w*h, D)
x = torch.cat((cls_embed, patch_embed), dim=1) # -> (1, w*h+1, D)
return x
class LayerScale(fl.WeightedModule): class LayerScale(fl.WeightedModule):
"""Scale the input tensor by a learnable parameter.""" """Scale the input tensor by a learnable parameter."""
@ -125,6 +179,7 @@ class PatchEncoder(fl.Chain):
self.patch_size = patch_size self.patch_size = patch_size
super().__init__( super().__init__(
fl.SetContext(context="dinov2_vit", key="input"), # save the original input
fl.Conv2d( fl.Conv2d(
in_channels=in_channels, in_channels=in_channels,
out_channels=out_channels, out_channels=out_channels,
@ -201,6 +256,10 @@ class Transformer(fl.Chain):
"""Alias for a Chain of TransformerLayer.""" """Alias for a Chain of TransformerLayer."""
class PositionalEncoder(fl.Residual):
"""Alias for a Residual."""
class Registers(fl.Concatenate): class Registers(fl.Concatenate):
"""Insert register tokens between CLS token and patches.""" """Insert register tokens between CLS token and patches."""
@ -243,6 +302,8 @@ class ViT(fl.Chain):
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
mlp_ratio: int = 4, mlp_ratio: int = 4,
num_registers: int = 0, num_registers: int = 0,
interpolate_antialias: bool = False,
interpolate_mode: str = "bicubic",
device: torch.device | str | None = None, device: torch.device | str | None = None,
dtype: torch.dtype | None = None, dtype: torch.dtype | None = None,
) -> None: ) -> None:
@ -257,6 +318,8 @@ class ViT(fl.Chain):
norm_eps: The epsilon value for normalization. norm_eps: The epsilon value for normalization.
mlp_ratio: The ratio for the multi-layer perceptron (MLP). mlp_ratio: The ratio for the multi-layer perceptron (MLP).
num_registers: The number of registers. num_registers: The number of registers.
interpolate_antialias: Whether to use antialiasing for interpolation.
interpolate_mode: The interpolation mode.
device: The PyTorch device to use. device: The PyTorch device to use.
dtype: The PyTorch data type to use. dtype: The PyTorch data type to use.
""" """
@ -286,19 +349,32 @@ class ViT(fl.Chain):
), ),
dim=1, dim=1,
), ),
# TODO: support https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/models/vision_transformer.py#L179
PositionalEncoder( PositionalEncoder(
PositionalEmbedding(
sequence_length=num_patches**2 + 1, sequence_length=num_patches**2 + 1,
embedding_dim=embedding_dim, embedding_dim=embedding_dim,
patch_size=patch_size,
device=device, device=device,
dtype=dtype, dtype=dtype,
), ),
fl.Chain(
fl.Parallel(
fl.Identity(),
fl.UseContext(context="dinov2_vit", key="input"),
),
InterpolateEmbedding(
mode=interpolate_mode,
antialias=interpolate_antialias,
patch_size=patch_size,
),
),
),
Transformer( Transformer(
TransformerLayer( TransformerLayer(
embedding_dim=embedding_dim, embedding_dim=embedding_dim,
num_heads=num_heads, num_heads=num_heads,
norm_eps=norm_eps,
mlp_ratio=mlp_ratio, mlp_ratio=mlp_ratio,
norm_eps=norm_eps,
device=device, device=device,
dtype=dtype, dtype=dtype,
) )
@ -320,3 +396,10 @@ class ViT(fl.Chain):
dtype=dtype, dtype=dtype,
) )
self.insert_before_type(Transformer, registers) self.insert_before_type(Transformer, registers)
def init_context(self) -> Contexts:
return {
"dinov2_vit": {
"input": None,
},
}