refiners/tests/foundationals/dinov2/test_dinov2.py

148 lines
4.9 KiB
Python
Raw Normal View History

from math import isclose
from pathlib import Path
from warnings import warn
import pytest
import torch
from transformers import AutoModel # type: ignore
from transformers.models.dinov2.modeling_dinov2 import Dinov2Model # type: ignore
from refiners.fluxion.utils import load_from_safetensors, manual_seed
from refiners.foundationals.dinov2 import (
DINOv2_base,
DINOv2_base_reg,
DINOv2_large,
DINOv2_large_reg,
DINOv2_small,
DINOv2_small_reg,
)
from refiners.foundationals.dinov2.vit import ViT
FLAVORS = [
"dinov2_vits14",
"dinov2_vitb14",
"dinov2_vitl14",
"dinov2_vits14_reg4",
"dinov2_vitb14_reg4",
"dinov2_vitl14_reg4",
]
@pytest.fixture(scope="module", params=FLAVORS)
def flavor(request: pytest.FixtureRequest) -> str:
return request.param
# Temporary: see comments in `test_encoder_only`
@pytest.fixture(scope="module")
def seed_expected_norm(flavor: str) -> tuple[int, float]:
match flavor:
case "dinov2_vits14":
return (42, 1977.9213867)
case "dinov2_vitb14":
return (42, 1902.6384277)
case "dinov2_vitl14":
return (42, 1763.9187011)
case "dinov2_vits14_reg4":
return (42, 989.2380981)
case "dinov2_vitb14_reg4":
return (42, 974.4362182)
case "dinov2_vitl14_reg4":
return (42, 924.8797607)
case _:
raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
@pytest.fixture(scope="module")
def our_backbone(test_weights_path: Path, flavor: str, test_device: torch.device) -> ViT:
weights = test_weights_path / f"{flavor}_pretrain.safetensors"
if not weights.is_file():
warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True)
match flavor:
case "dinov2_vits14":
backbone = DINOv2_small(device=test_device)
case "dinov2_vitb14":
backbone = DINOv2_base(device=test_device)
case "dinov2_vitl14":
backbone = DINOv2_large(device=test_device)
case "dinov2_vits14_reg4":
backbone = DINOv2_small_reg(device=test_device)
case "dinov2_vitb14_reg4":
backbone = DINOv2_base_reg(device=test_device)
case "dinov2_vitl14_reg4":
backbone = DINOv2_large_reg(device=test_device)
case _:
raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
tensors = load_from_safetensors(weights)
backbone.load_state_dict(tensors)
return backbone
@pytest.fixture(scope="module")
def dinov2_weights_path(test_weights_path: Path, flavor: str):
# TODO: At the time of writing, those are not yet supported in transformers
# (https://github.com/huggingface/transformers/issues/27379). Alternatively, it is also possible to use
# facebookresearch/dinov2 directly (https://github.com/finegrain-ai/refiners/pull/132).
if flavor.endswith("_reg4"):
warn(f"DINOv2 with registers are not yet supported in Hugging Face, skipping")
pytest.skip(allow_module_level=True)
match flavor:
case "dinov2_vits14":
name = "dinov2-small"
case "dinov2_vitb14":
name = "dinov2-base"
case "dinov2_vitl14":
name = "dinov2-large"
case _:
raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
r = test_weights_path / "facebook" / name
if not r.is_dir():
warn(f"could not find DINOv2 weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture(scope="module")
def ref_backbone(dinov2_weights_path: Path, test_device: torch.device) -> Dinov2Model:
backbone = AutoModel.from_pretrained(dinov2_weights_path) # type: ignore
assert isinstance(backbone, Dinov2Model)
return backbone.to(test_device) # type: ignore
def test_encoder(
ref_backbone: Dinov2Model,
our_backbone: ViT,
test_device: torch.device,
):
manual_seed(42)
# Position encoding interpolation [1] at runtime is not supported yet. So stick to the default image resolution
# e.g. using (224, 224) pixels as input would give a runtime error (sequence size mismatch)
# [1]: https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/models/vision_transformer.py#L179
assert our_backbone.image_size == 518
x = torch.randn(1, 3, 518, 518).to(test_device)
with torch.no_grad():
ref_features = ref_backbone(x).last_hidden_state
our_features = our_backbone(x)
assert (our_features - ref_features).abs().max() < 1e-3
# Mainly for DINOv2 + registers coverage (this test can be removed once `test_encoder` supports all flavors)
def test_encoder_only(
our_backbone: ViT,
seed_expected_norm: tuple[int, float],
test_device: torch.device,
):
seed, expected_norm = seed_expected_norm
manual_seed(seed)
x = torch.randn(1, 3, 518, 518).to(test_device)
our_features = our_backbone(x)
assert isclose(our_features.norm().item(), expected_norm, rel_tol=1e-04)