refiners/tests/foundationals/dinov2/test_dinov2.py
Pierre Chapuis 73089a4e2d
Some checks failed
CI / lint_and_typecheck (push) Has been cancelled
Deploy docs to GitHub Pages / Deploy docs (push) Has been cancelled
Spell checker / Spell check (push) Has been cancelled
fix typing issue in dinov2 test
2024-09-18 10:44:24 +02:00

172 lines
4.7 KiB
Python

from pathlib import Path
from typing import Any
from warnings import warn
import pytest
import torch
from refiners.fluxion.utils import load_from_safetensors, load_tensors, manual_seed, no_grad
from refiners.foundationals.dinov2.dinov2 import (
DINOv2_base,
DINOv2_base_reg,
DINOv2_giant,
DINOv2_giant_reg,
DINOv2_large,
DINOv2_large_reg,
DINOv2_small,
DINOv2_small_reg,
)
from refiners.foundationals.dinov2.vit import ViT
FLAVORS_MAP = {
"dinov2_vits14": DINOv2_small,
"dinov2_vits14_reg": DINOv2_small_reg,
"dinov2_vitb14": DINOv2_base,
"dinov2_vitb14_reg": DINOv2_base_reg,
"dinov2_vitl14": DINOv2_large,
"dinov2_vitl14_reg": DINOv2_large_reg,
"dinov2_vitg14": DINOv2_giant,
"dinov2_vitg14_reg": DINOv2_giant_reg,
}
@pytest.fixture(scope="module", params=[224, 518])
def resolution(request: pytest.FixtureRequest) -> int:
return request.param
@pytest.fixture(scope="module", params=FLAVORS_MAP.keys())
def flavor(request: pytest.FixtureRequest) -> str:
return request.param
@pytest.fixture(scope="module")
def dinov2_repo_path(test_repos_path: Path) -> Path:
repo = test_repos_path / "dinov2"
if not repo.exists():
warn(f"could not find DINOv2 GitHub repo at {repo}, skipping")
pytest.skip(allow_module_level=True)
return repo
@pytest.fixture(scope="module")
def ref_model(
flavor: str,
dinov2_repo_path: Path,
test_weights_path: Path,
test_device: torch.device,
) -> torch.nn.Module:
kwargs: dict[str, Any] = {}
if "reg" not in flavor:
kwargs["interpolate_offset"] = 0.0
model: torch.nn.Module = torch.hub.load( # type: ignore
model=flavor,
repo_or_dir=str(dinov2_repo_path),
source="local",
pretrained=False, # to turn off automatic weights download (see load_state_dict below)
**kwargs,
)
model = model.to(device=test_device)
flavor = flavor.replace("_reg", "_reg4")
weights = test_weights_path / f"{flavor}_pretrain.pth"
if not weights.is_file():
warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True)
model.load_state_dict(load_tensors(weights, device=test_device))
assert isinstance(model, torch.nn.Module)
return model
@pytest.fixture(scope="module")
def our_model(
test_weights_path: Path,
flavor: str,
test_device: torch.device,
) -> ViT:
model = FLAVORS_MAP[flavor](device=test_device)
flavor = flavor.replace("_reg", "_reg4")
weights = test_weights_path / f"{flavor}_pretrain.safetensors"
if not weights.is_file():
warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True)
tensors = load_from_safetensors(weights)
model.load_state_dict(tensors)
return model
@no_grad()
def test_dinov2_facebook_weights(
ref_model: torch.nn.Module,
our_model: ViT,
resolution: int,
test_device: torch.device,
) -> None:
manual_seed(2)
input_data = torch.randn(
(1, 3, resolution, resolution),
device=test_device,
)
ref_output = ref_model(input_data, is_training=True)
ref_cls = ref_output["x_norm_clstoken"]
ref_reg = ref_output["x_norm_regtokens"]
ref_patch = ref_output["x_norm_patchtokens"]
our_output = our_model(input_data)
our_cls = our_output[:, 0]
our_reg = our_output[:, 1 : our_model.num_registers + 1]
our_patch = our_output[:, our_model.num_registers + 1 :]
assert torch.allclose(ref_cls, our_cls, atol=1e-4)
assert torch.allclose(ref_reg, our_reg, atol=1e-4)
assert torch.allclose(ref_patch, our_patch, atol=3e-3)
@no_grad()
def test_dinov2_float16(
resolution: int,
test_device: torch.device,
) -> None:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
pytest.skip()
model = DINOv2_small(device=test_device, dtype=torch.float16)
manual_seed(2)
input_data = torch.randn(
(1, 3, resolution, resolution),
device=test_device,
dtype=torch.float16,
)
output = model(input_data)
sequence_length = (resolution // model.patch_size) ** 2 + 1
assert output.shape == (1, sequence_length, model.embedding_dim)
assert output.dtype == torch.float16
@no_grad()
def test_dinov2_batch_size(
resolution: int,
test_device: torch.device,
) -> None:
model = DINOv2_small(device=test_device)
batch_size = 4
manual_seed(2)
input_data = torch.randn(
(batch_size, 3, resolution, resolution),
device=test_device,
)
output = model(input_data)
sequence_length = (resolution // model.patch_size) ** 2 + 1
assert output.shape == (batch_size, sequence_length, model.embedding_dim)