mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
implement dinov2 positional embedding interpolation
This commit is contained in:
parent
0336bc78b5
commit
4f94dfb494
|
@ -144,7 +144,7 @@ exclude_also = [
|
|||
|
||||
[tool.typos.default]
|
||||
extend-words = { adaptee = "adaptee" }
|
||||
extend-ignore-identifiers-re = ["NDArray*", "interm"]
|
||||
extend-ignore-identifiers-re = ["NDArray*", "interm", "af000ded"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
filterwarnings = [
|
||||
|
|
|
@ -20,7 +20,7 @@ def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:
|
|||
|
||||
rename_keys: list[tuple[str, str]] = [
|
||||
("cls_token", "Concatenate.ClassToken.Parameter.weight"),
|
||||
("pos_embed", "PositionalEncoder.Parameter.weight"),
|
||||
("pos_embed", "PositionalEncoder.PositionalEmbedding.Parameter.weight"),
|
||||
("patch_embed.proj.weight", "Concatenate.PatchEncoder.Conv2d.weight"),
|
||||
("patch_embed.proj.bias", "Concatenate.PatchEncoder.Conv2d.bias"),
|
||||
("norm.weight", "LayerNorm.weight"),
|
||||
|
|
|
@ -688,37 +688,37 @@ def convert_dinov2():
|
|||
"convert_dinov2.py",
|
||||
"tests/weights/dinov2_vits14_pretrain.pth",
|
||||
"tests/weights/dinov2_vits14_pretrain.safetensors",
|
||||
expected_hash="b7f9b294",
|
||||
expected_hash="af000ded",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_dinov2.py",
|
||||
"tests/weights/dinov2_vitb14_pretrain.pth",
|
||||
"tests/weights/dinov2_vitb14_pretrain.safetensors",
|
||||
expected_hash="d72c767b",
|
||||
expected_hash="d6294087",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_dinov2.py",
|
||||
"tests/weights/dinov2_vitl14_pretrain.pth",
|
||||
"tests/weights/dinov2_vitl14_pretrain.safetensors",
|
||||
expected_hash="71eb98d1",
|
||||
expected_hash="ddd4819f",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_dinov2.py",
|
||||
"tests/weights/dinov2_vits14_reg4_pretrain.pth",
|
||||
"tests/weights/dinov2_vits14_reg4_pretrain.safetensors",
|
||||
expected_hash="89118b46",
|
||||
expected_hash="080247c7",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_dinov2.py",
|
||||
"tests/weights/dinov2_vitb14_reg4_pretrain.pth",
|
||||
"tests/weights/dinov2_vitb14_reg4_pretrain.safetensors",
|
||||
expected_hash="b0296f77",
|
||||
expected_hash="5cd4d408",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_dinov2.py",
|
||||
"tests/weights/dinov2_vitl14_reg4_pretrain.pth",
|
||||
"tests/weights/dinov2_vitl14_reg4_pretrain.safetensors",
|
||||
expected_hash="b3d877dc",
|
||||
expected_hash="b1221702",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -146,6 +146,7 @@ class DINOv2_small_reg(ViT):
|
|||
num_layers (int): 12
|
||||
num_heads (int): 6
|
||||
num_registers (int): 4
|
||||
interpolate_antialias (bool): True
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -166,6 +167,7 @@ class DINOv2_small_reg(ViT):
|
|||
num_layers=12,
|
||||
num_heads=6,
|
||||
num_registers=4,
|
||||
interpolate_antialias=True,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
@ -185,6 +187,7 @@ class DINOv2_base_reg(ViT):
|
|||
num_layers (int): 12
|
||||
num_heads (int): 12
|
||||
num_registers (int): 4
|
||||
interpolate_antialias (bool): True
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -205,6 +208,7 @@ class DINOv2_base_reg(ViT):
|
|||
num_layers=12,
|
||||
num_heads=12,
|
||||
num_registers=4,
|
||||
interpolate_antialias=True,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
@ -224,6 +228,7 @@ class DINOv2_large_reg(ViT):
|
|||
num_layers (int): 24
|
||||
num_heads (int): 16
|
||||
num_registers (int): 4
|
||||
interpolate_antialias (bool): True
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -244,6 +249,7 @@ class DINOv2_large_reg(ViT):
|
|||
num_layers=24,
|
||||
num_heads=16,
|
||||
num_registers=4,
|
||||
interpolate_antialias=True,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
@ -263,6 +269,7 @@ class DINOv2_large_reg(ViT):
|
|||
# num_layers=40,
|
||||
# num_heads=24,
|
||||
# num_registers=4,
|
||||
# interpolate_antialias=True,
|
||||
# device=device,
|
||||
# dtype=dtype,
|
||||
# )
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
from math import sqrt
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.context import Contexts
|
||||
from refiners.fluxion.layers.activations import Activation
|
||||
from refiners.fluxion.utils import interpolate
|
||||
|
||||
|
||||
class ClassToken(fl.Chain):
|
||||
|
@ -27,18 +30,20 @@ class ClassToken(fl.Chain):
|
|||
)
|
||||
|
||||
|
||||
class PositionalEncoder(fl.Residual):
|
||||
"""Encode the position of each patch in the input."""
|
||||
class PositionalEmbedding(fl.Chain):
|
||||
"""Learnable positional embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sequence_length: int,
|
||||
embedding_dim: int,
|
||||
patch_size: int,
|
||||
device: torch.device | str | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
self.num_patches = sequence_length
|
||||
self.sequence_length = sequence_length
|
||||
self.embedding_dim = embedding_dim
|
||||
self.patch_size = patch_size
|
||||
|
||||
super().__init__(
|
||||
fl.Parameter(
|
||||
|
@ -49,6 +54,55 @@ class PositionalEncoder(fl.Residual):
|
|||
)
|
||||
|
||||
|
||||
class InterpolateEmbedding(fl.Module):
|
||||
"""Interpolate the positional embeddings to match the input shape."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mode: str,
|
||||
antialias: bool,
|
||||
patch_size: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.mode = mode
|
||||
self.antialias = antialias
|
||||
self.patch_size = patch_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
cls_embed = x[:, :1, :] # -> (1, 1, D)
|
||||
patch_embed = x[:, 1:, :] # -> (1, N, D)
|
||||
|
||||
N = patch_embed.shape[1]
|
||||
D = patch_embed.shape[2]
|
||||
M = int(sqrt(N))
|
||||
W = input.shape[2]
|
||||
H = input.shape[3]
|
||||
assert M * M == N, "The sequence length must be a square number."
|
||||
|
||||
patch_embed = patch_embed.reshape(1, M, M, D) # -> (1, M, M, D)
|
||||
patch_embed = patch_embed.permute(0, 3, 1, 2) # -> (1, D, M, M)
|
||||
patch_embed = interpolate(
|
||||
x=patch_embed.to(dtype=torch.float32),
|
||||
mode=self.mode,
|
||||
antialias=self.antialias,
|
||||
size=torch.Size(
|
||||
(
|
||||
W // self.patch_size,
|
||||
H // self.patch_size,
|
||||
)
|
||||
),
|
||||
).to(dtype=cls_embed.dtype) # -> (1, D, w, h)
|
||||
patch_embed = patch_embed.permute(0, 2, 3, 1) # -> (1, w, h, D)
|
||||
patch_embed = patch_embed.reshape(1, -1, D) # -> (1, w*h, D)
|
||||
|
||||
x = torch.cat((cls_embed, patch_embed), dim=1) # -> (1, w*h+1, D)
|
||||
return x
|
||||
|
||||
|
||||
class LayerScale(fl.WeightedModule):
|
||||
"""Scale the input tensor by a learnable parameter."""
|
||||
|
||||
|
@ -125,6 +179,7 @@ class PatchEncoder(fl.Chain):
|
|||
self.patch_size = patch_size
|
||||
|
||||
super().__init__(
|
||||
fl.SetContext(context="dinov2_vit", key="input"), # save the original input
|
||||
fl.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
|
@ -201,6 +256,10 @@ class Transformer(fl.Chain):
|
|||
"""Alias for a Chain of TransformerLayer."""
|
||||
|
||||
|
||||
class PositionalEncoder(fl.Residual):
|
||||
"""Alias for a Residual."""
|
||||
|
||||
|
||||
class Registers(fl.Concatenate):
|
||||
"""Insert register tokens between CLS token and patches."""
|
||||
|
||||
|
@ -243,6 +302,8 @@ class ViT(fl.Chain):
|
|||
norm_eps: float = 1e-6,
|
||||
mlp_ratio: int = 4,
|
||||
num_registers: int = 0,
|
||||
interpolate_antialias: bool = False,
|
||||
interpolate_mode: str = "bicubic",
|
||||
device: torch.device | str | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
|
@ -257,6 +318,8 @@ class ViT(fl.Chain):
|
|||
norm_eps: The epsilon value for normalization.
|
||||
mlp_ratio: The ratio for the multi-layer perceptron (MLP).
|
||||
num_registers: The number of registers.
|
||||
interpolate_antialias: Whether to use antialiasing for interpolation.
|
||||
interpolate_mode: The interpolation mode.
|
||||
device: The PyTorch device to use.
|
||||
dtype: The PyTorch data type to use.
|
||||
"""
|
||||
|
@ -286,19 +349,32 @@ class ViT(fl.Chain):
|
|||
),
|
||||
dim=1,
|
||||
),
|
||||
# TODO: support https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/models/vision_transformer.py#L179
|
||||
PositionalEncoder(
|
||||
PositionalEmbedding(
|
||||
sequence_length=num_patches**2 + 1,
|
||||
embedding_dim=embedding_dim,
|
||||
patch_size=patch_size,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
fl.Chain(
|
||||
fl.Parallel(
|
||||
fl.Identity(),
|
||||
fl.UseContext(context="dinov2_vit", key="input"),
|
||||
),
|
||||
InterpolateEmbedding(
|
||||
mode=interpolate_mode,
|
||||
antialias=interpolate_antialias,
|
||||
patch_size=patch_size,
|
||||
),
|
||||
),
|
||||
),
|
||||
Transformer(
|
||||
TransformerLayer(
|
||||
embedding_dim=embedding_dim,
|
||||
num_heads=num_heads,
|
||||
norm_eps=norm_eps,
|
||||
mlp_ratio=mlp_ratio,
|
||||
norm_eps=norm_eps,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
@ -320,3 +396,10 @@ class ViT(fl.Chain):
|
|||
dtype=dtype,
|
||||
)
|
||||
self.insert_before_type(Transformer, registers)
|
||||
|
||||
def init_context(self) -> Contexts:
|
||||
return {
|
||||
"dinov2_vit": {
|
||||
"input": None,
|
||||
},
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue