mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
implement dinov2 positional embedding interpolation
This commit is contained in:
parent
0336bc78b5
commit
4f94dfb494
|
@ -144,7 +144,7 @@ exclude_also = [
|
||||||
|
|
||||||
[tool.typos.default]
|
[tool.typos.default]
|
||||||
extend-words = { adaptee = "adaptee" }
|
extend-words = { adaptee = "adaptee" }
|
||||||
extend-ignore-identifiers-re = ["NDArray*", "interm"]
|
extend-ignore-identifiers-re = ["NDArray*", "interm", "af000ded"]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
filterwarnings = [
|
filterwarnings = [
|
||||||
|
|
|
@ -20,7 +20,7 @@ def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:
|
||||||
|
|
||||||
rename_keys: list[tuple[str, str]] = [
|
rename_keys: list[tuple[str, str]] = [
|
||||||
("cls_token", "Concatenate.ClassToken.Parameter.weight"),
|
("cls_token", "Concatenate.ClassToken.Parameter.weight"),
|
||||||
("pos_embed", "PositionalEncoder.Parameter.weight"),
|
("pos_embed", "PositionalEncoder.PositionalEmbedding.Parameter.weight"),
|
||||||
("patch_embed.proj.weight", "Concatenate.PatchEncoder.Conv2d.weight"),
|
("patch_embed.proj.weight", "Concatenate.PatchEncoder.Conv2d.weight"),
|
||||||
("patch_embed.proj.bias", "Concatenate.PatchEncoder.Conv2d.bias"),
|
("patch_embed.proj.bias", "Concatenate.PatchEncoder.Conv2d.bias"),
|
||||||
("norm.weight", "LayerNorm.weight"),
|
("norm.weight", "LayerNorm.weight"),
|
||||||
|
|
|
@ -688,37 +688,37 @@ def convert_dinov2():
|
||||||
"convert_dinov2.py",
|
"convert_dinov2.py",
|
||||||
"tests/weights/dinov2_vits14_pretrain.pth",
|
"tests/weights/dinov2_vits14_pretrain.pth",
|
||||||
"tests/weights/dinov2_vits14_pretrain.safetensors",
|
"tests/weights/dinov2_vits14_pretrain.safetensors",
|
||||||
expected_hash="b7f9b294",
|
expected_hash="af000ded",
|
||||||
)
|
)
|
||||||
run_conversion_script(
|
run_conversion_script(
|
||||||
"convert_dinov2.py",
|
"convert_dinov2.py",
|
||||||
"tests/weights/dinov2_vitb14_pretrain.pth",
|
"tests/weights/dinov2_vitb14_pretrain.pth",
|
||||||
"tests/weights/dinov2_vitb14_pretrain.safetensors",
|
"tests/weights/dinov2_vitb14_pretrain.safetensors",
|
||||||
expected_hash="d72c767b",
|
expected_hash="d6294087",
|
||||||
)
|
)
|
||||||
run_conversion_script(
|
run_conversion_script(
|
||||||
"convert_dinov2.py",
|
"convert_dinov2.py",
|
||||||
"tests/weights/dinov2_vitl14_pretrain.pth",
|
"tests/weights/dinov2_vitl14_pretrain.pth",
|
||||||
"tests/weights/dinov2_vitl14_pretrain.safetensors",
|
"tests/weights/dinov2_vitl14_pretrain.safetensors",
|
||||||
expected_hash="71eb98d1",
|
expected_hash="ddd4819f",
|
||||||
)
|
)
|
||||||
run_conversion_script(
|
run_conversion_script(
|
||||||
"convert_dinov2.py",
|
"convert_dinov2.py",
|
||||||
"tests/weights/dinov2_vits14_reg4_pretrain.pth",
|
"tests/weights/dinov2_vits14_reg4_pretrain.pth",
|
||||||
"tests/weights/dinov2_vits14_reg4_pretrain.safetensors",
|
"tests/weights/dinov2_vits14_reg4_pretrain.safetensors",
|
||||||
expected_hash="89118b46",
|
expected_hash="080247c7",
|
||||||
)
|
)
|
||||||
run_conversion_script(
|
run_conversion_script(
|
||||||
"convert_dinov2.py",
|
"convert_dinov2.py",
|
||||||
"tests/weights/dinov2_vitb14_reg4_pretrain.pth",
|
"tests/weights/dinov2_vitb14_reg4_pretrain.pth",
|
||||||
"tests/weights/dinov2_vitb14_reg4_pretrain.safetensors",
|
"tests/weights/dinov2_vitb14_reg4_pretrain.safetensors",
|
||||||
expected_hash="b0296f77",
|
expected_hash="5cd4d408",
|
||||||
)
|
)
|
||||||
run_conversion_script(
|
run_conversion_script(
|
||||||
"convert_dinov2.py",
|
"convert_dinov2.py",
|
||||||
"tests/weights/dinov2_vitl14_reg4_pretrain.pth",
|
"tests/weights/dinov2_vitl14_reg4_pretrain.pth",
|
||||||
"tests/weights/dinov2_vitl14_reg4_pretrain.safetensors",
|
"tests/weights/dinov2_vitl14_reg4_pretrain.safetensors",
|
||||||
expected_hash="b3d877dc",
|
expected_hash="b1221702",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -146,6 +146,7 @@ class DINOv2_small_reg(ViT):
|
||||||
num_layers (int): 12
|
num_layers (int): 12
|
||||||
num_heads (int): 6
|
num_heads (int): 6
|
||||||
num_registers (int): 4
|
num_registers (int): 4
|
||||||
|
interpolate_antialias (bool): True
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -166,6 +167,7 @@ class DINOv2_small_reg(ViT):
|
||||||
num_layers=12,
|
num_layers=12,
|
||||||
num_heads=6,
|
num_heads=6,
|
||||||
num_registers=4,
|
num_registers=4,
|
||||||
|
interpolate_antialias=True,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
@ -185,6 +187,7 @@ class DINOv2_base_reg(ViT):
|
||||||
num_layers (int): 12
|
num_layers (int): 12
|
||||||
num_heads (int): 12
|
num_heads (int): 12
|
||||||
num_registers (int): 4
|
num_registers (int): 4
|
||||||
|
interpolate_antialias (bool): True
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -205,6 +208,7 @@ class DINOv2_base_reg(ViT):
|
||||||
num_layers=12,
|
num_layers=12,
|
||||||
num_heads=12,
|
num_heads=12,
|
||||||
num_registers=4,
|
num_registers=4,
|
||||||
|
interpolate_antialias=True,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
@ -224,6 +228,7 @@ class DINOv2_large_reg(ViT):
|
||||||
num_layers (int): 24
|
num_layers (int): 24
|
||||||
num_heads (int): 16
|
num_heads (int): 16
|
||||||
num_registers (int): 4
|
num_registers (int): 4
|
||||||
|
interpolate_antialias (bool): True
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -244,6 +249,7 @@ class DINOv2_large_reg(ViT):
|
||||||
num_layers=24,
|
num_layers=24,
|
||||||
num_heads=16,
|
num_heads=16,
|
||||||
num_registers=4,
|
num_registers=4,
|
||||||
|
interpolate_antialias=True,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
@ -263,6 +269,7 @@ class DINOv2_large_reg(ViT):
|
||||||
# num_layers=40,
|
# num_layers=40,
|
||||||
# num_heads=24,
|
# num_heads=24,
|
||||||
# num_registers=4,
|
# num_registers=4,
|
||||||
|
# interpolate_antialias=True,
|
||||||
# device=device,
|
# device=device,
|
||||||
# dtype=dtype,
|
# dtype=dtype,
|
||||||
# )
|
# )
|
||||||
|
|
|
@ -1,10 +1,13 @@
|
||||||
|
from math import sqrt
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
|
from refiners.fluxion.context import Contexts
|
||||||
from refiners.fluxion.layers.activations import Activation
|
from refiners.fluxion.layers.activations import Activation
|
||||||
|
from refiners.fluxion.utils import interpolate
|
||||||
|
|
||||||
|
|
||||||
class ClassToken(fl.Chain):
|
class ClassToken(fl.Chain):
|
||||||
|
@ -27,18 +30,20 @@ class ClassToken(fl.Chain):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PositionalEncoder(fl.Residual):
|
class PositionalEmbedding(fl.Chain):
|
||||||
"""Encode the position of each patch in the input."""
|
"""Learnable positional embedding."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
sequence_length: int,
|
sequence_length: int,
|
||||||
embedding_dim: int,
|
embedding_dim: int,
|
||||||
|
patch_size: int,
|
||||||
device: torch.device | str | None = None,
|
device: torch.device | str | None = None,
|
||||||
dtype: torch.dtype | None = None,
|
dtype: torch.dtype | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.num_patches = sequence_length
|
self.sequence_length = sequence_length
|
||||||
self.embedding_dim = embedding_dim
|
self.embedding_dim = embedding_dim
|
||||||
|
self.patch_size = patch_size
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Parameter(
|
fl.Parameter(
|
||||||
|
@ -49,6 +54,55 @@ class PositionalEncoder(fl.Residual):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InterpolateEmbedding(fl.Module):
|
||||||
|
"""Interpolate the positional embeddings to match the input shape."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mode: str,
|
||||||
|
antialias: bool,
|
||||||
|
patch_size: int,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.mode = mode
|
||||||
|
self.antialias = antialias
|
||||||
|
self.patch_size = patch_size
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
input: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
cls_embed = x[:, :1, :] # -> (1, 1, D)
|
||||||
|
patch_embed = x[:, 1:, :] # -> (1, N, D)
|
||||||
|
|
||||||
|
N = patch_embed.shape[1]
|
||||||
|
D = patch_embed.shape[2]
|
||||||
|
M = int(sqrt(N))
|
||||||
|
W = input.shape[2]
|
||||||
|
H = input.shape[3]
|
||||||
|
assert M * M == N, "The sequence length must be a square number."
|
||||||
|
|
||||||
|
patch_embed = patch_embed.reshape(1, M, M, D) # -> (1, M, M, D)
|
||||||
|
patch_embed = patch_embed.permute(0, 3, 1, 2) # -> (1, D, M, M)
|
||||||
|
patch_embed = interpolate(
|
||||||
|
x=patch_embed.to(dtype=torch.float32),
|
||||||
|
mode=self.mode,
|
||||||
|
antialias=self.antialias,
|
||||||
|
size=torch.Size(
|
||||||
|
(
|
||||||
|
W // self.patch_size,
|
||||||
|
H // self.patch_size,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
).to(dtype=cls_embed.dtype) # -> (1, D, w, h)
|
||||||
|
patch_embed = patch_embed.permute(0, 2, 3, 1) # -> (1, w, h, D)
|
||||||
|
patch_embed = patch_embed.reshape(1, -1, D) # -> (1, w*h, D)
|
||||||
|
|
||||||
|
x = torch.cat((cls_embed, patch_embed), dim=1) # -> (1, w*h+1, D)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class LayerScale(fl.WeightedModule):
|
class LayerScale(fl.WeightedModule):
|
||||||
"""Scale the input tensor by a learnable parameter."""
|
"""Scale the input tensor by a learnable parameter."""
|
||||||
|
|
||||||
|
@ -125,6 +179,7 @@ class PatchEncoder(fl.Chain):
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
fl.SetContext(context="dinov2_vit", key="input"), # save the original input
|
||||||
fl.Conv2d(
|
fl.Conv2d(
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
out_channels=out_channels,
|
out_channels=out_channels,
|
||||||
|
@ -201,6 +256,10 @@ class Transformer(fl.Chain):
|
||||||
"""Alias for a Chain of TransformerLayer."""
|
"""Alias for a Chain of TransformerLayer."""
|
||||||
|
|
||||||
|
|
||||||
|
class PositionalEncoder(fl.Residual):
|
||||||
|
"""Alias for a Residual."""
|
||||||
|
|
||||||
|
|
||||||
class Registers(fl.Concatenate):
|
class Registers(fl.Concatenate):
|
||||||
"""Insert register tokens between CLS token and patches."""
|
"""Insert register tokens between CLS token and patches."""
|
||||||
|
|
||||||
|
@ -243,6 +302,8 @@ class ViT(fl.Chain):
|
||||||
norm_eps: float = 1e-6,
|
norm_eps: float = 1e-6,
|
||||||
mlp_ratio: int = 4,
|
mlp_ratio: int = 4,
|
||||||
num_registers: int = 0,
|
num_registers: int = 0,
|
||||||
|
interpolate_antialias: bool = False,
|
||||||
|
interpolate_mode: str = "bicubic",
|
||||||
device: torch.device | str | None = None,
|
device: torch.device | str | None = None,
|
||||||
dtype: torch.dtype | None = None,
|
dtype: torch.dtype | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -257,6 +318,8 @@ class ViT(fl.Chain):
|
||||||
norm_eps: The epsilon value for normalization.
|
norm_eps: The epsilon value for normalization.
|
||||||
mlp_ratio: The ratio for the multi-layer perceptron (MLP).
|
mlp_ratio: The ratio for the multi-layer perceptron (MLP).
|
||||||
num_registers: The number of registers.
|
num_registers: The number of registers.
|
||||||
|
interpolate_antialias: Whether to use antialiasing for interpolation.
|
||||||
|
interpolate_mode: The interpolation mode.
|
||||||
device: The PyTorch device to use.
|
device: The PyTorch device to use.
|
||||||
dtype: The PyTorch data type to use.
|
dtype: The PyTorch data type to use.
|
||||||
"""
|
"""
|
||||||
|
@ -286,19 +349,32 @@ class ViT(fl.Chain):
|
||||||
),
|
),
|
||||||
dim=1,
|
dim=1,
|
||||||
),
|
),
|
||||||
# TODO: support https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/models/vision_transformer.py#L179
|
|
||||||
PositionalEncoder(
|
PositionalEncoder(
|
||||||
|
PositionalEmbedding(
|
||||||
sequence_length=num_patches**2 + 1,
|
sequence_length=num_patches**2 + 1,
|
||||||
embedding_dim=embedding_dim,
|
embedding_dim=embedding_dim,
|
||||||
|
patch_size=patch_size,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
),
|
),
|
||||||
|
fl.Chain(
|
||||||
|
fl.Parallel(
|
||||||
|
fl.Identity(),
|
||||||
|
fl.UseContext(context="dinov2_vit", key="input"),
|
||||||
|
),
|
||||||
|
InterpolateEmbedding(
|
||||||
|
mode=interpolate_mode,
|
||||||
|
antialias=interpolate_antialias,
|
||||||
|
patch_size=patch_size,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
Transformer(
|
Transformer(
|
||||||
TransformerLayer(
|
TransformerLayer(
|
||||||
embedding_dim=embedding_dim,
|
embedding_dim=embedding_dim,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
norm_eps=norm_eps,
|
|
||||||
mlp_ratio=mlp_ratio,
|
mlp_ratio=mlp_ratio,
|
||||||
|
norm_eps=norm_eps,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
@ -320,3 +396,10 @@ class ViT(fl.Chain):
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
self.insert_before_type(Transformer, registers)
|
self.insert_before_type(Transformer, registers)
|
||||||
|
|
||||||
|
def init_context(self) -> Contexts:
|
||||||
|
return {
|
||||||
|
"dinov2_vit": {
|
||||||
|
"input": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue