mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-13 00:28:14 +00:00
dinov2: add some coverage for registers
Those are not supported yet in HF: so just compared with a precomputed norm. Note: in the initial PR [1] the Refiners' implementation has been tested against the official code using Torch Hub. [1]: https://github.com/finegrain-ai/refiners/pull/132#issuecomment-1852021656
This commit is contained in:
parent
f0ea1a2509
commit
e7892254eb
|
@ -1,3 +1,4 @@
|
||||||
|
from math import isclose
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
|
|
||||||
|
@ -7,16 +8,23 @@ from transformers import AutoModel # type: ignore
|
||||||
from transformers.models.dinov2.modeling_dinov2 import Dinov2Model # type: ignore
|
from transformers.models.dinov2.modeling_dinov2 import Dinov2Model # type: ignore
|
||||||
|
|
||||||
from refiners.fluxion.utils import load_from_safetensors, manual_seed
|
from refiners.fluxion.utils import load_from_safetensors, manual_seed
|
||||||
from refiners.foundationals.dinov2 import DINOv2_base, DINOv2_large, DINOv2_small
|
from refiners.foundationals.dinov2 import (
|
||||||
|
DINOv2_base,
|
||||||
|
DINOv2_base_reg,
|
||||||
|
DINOv2_large,
|
||||||
|
DINOv2_large_reg,
|
||||||
|
DINOv2_small,
|
||||||
|
DINOv2_small_reg,
|
||||||
|
)
|
||||||
from refiners.foundationals.dinov2.vit import ViT
|
from refiners.foundationals.dinov2.vit import ViT
|
||||||
|
|
||||||
# TODO: add DINOv2 with registers ("dinov2_vits14_reg", etc). At the time of writing, those are not yet supported in
|
|
||||||
# transformers (https://github.com/huggingface/transformers/issues/27379). Alternatively, it is also possible to use
|
|
||||||
# facebookresearch/dinov2 directly (https://github.com/finegrain-ai/refiners/pull/132).
|
|
||||||
FLAVORS = [
|
FLAVORS = [
|
||||||
"dinov2_vits14",
|
"dinov2_vits14",
|
||||||
"dinov2_vitb14",
|
"dinov2_vitb14",
|
||||||
"dinov2_vitl14",
|
"dinov2_vitl14",
|
||||||
|
"dinov2_vits14_reg4",
|
||||||
|
"dinov2_vitb14_reg4",
|
||||||
|
"dinov2_vitl14_reg4",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,6 +33,26 @@ def flavor(request: pytest.FixtureRequest) -> str:
|
||||||
return request.param
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
|
# Temporary: see comments in `test_encoder_only`
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def seed_expected_norm(flavor: str) -> tuple[int, float]:
|
||||||
|
match flavor:
|
||||||
|
case "dinov2_vits14":
|
||||||
|
return (42, 1977.9213867)
|
||||||
|
case "dinov2_vitb14":
|
||||||
|
return (42, 1902.6384277)
|
||||||
|
case "dinov2_vitl14":
|
||||||
|
return (42, 1763.9187011)
|
||||||
|
case "dinov2_vits14_reg4":
|
||||||
|
return (42, 989.2380981)
|
||||||
|
case "dinov2_vitb14_reg4":
|
||||||
|
return (42, 974.4362182)
|
||||||
|
case "dinov2_vitl14_reg4":
|
||||||
|
return (42, 924.8797607)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def our_backbone(test_weights_path: Path, flavor: str, test_device: torch.device) -> ViT:
|
def our_backbone(test_weights_path: Path, flavor: str, test_device: torch.device) -> ViT:
|
||||||
weights = test_weights_path / f"{flavor}_pretrain.safetensors"
|
weights = test_weights_path / f"{flavor}_pretrain.safetensors"
|
||||||
|
@ -38,6 +66,12 @@ def our_backbone(test_weights_path: Path, flavor: str, test_device: torch.device
|
||||||
backbone = DINOv2_base(device=test_device)
|
backbone = DINOv2_base(device=test_device)
|
||||||
case "dinov2_vitl14":
|
case "dinov2_vitl14":
|
||||||
backbone = DINOv2_large(device=test_device)
|
backbone = DINOv2_large(device=test_device)
|
||||||
|
case "dinov2_vits14_reg4":
|
||||||
|
backbone = DINOv2_small_reg(device=test_device)
|
||||||
|
case "dinov2_vitb14_reg4":
|
||||||
|
backbone = DINOv2_base_reg(device=test_device)
|
||||||
|
case "dinov2_vitl14_reg4":
|
||||||
|
backbone = DINOv2_large_reg(device=test_device)
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
|
raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
|
||||||
tensors = load_from_safetensors(weights)
|
tensors = load_from_safetensors(weights)
|
||||||
|
@ -47,6 +81,12 @@ def our_backbone(test_weights_path: Path, flavor: str, test_device: torch.device
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def dinov2_weights_path(test_weights_path: Path, flavor: str):
|
def dinov2_weights_path(test_weights_path: Path, flavor: str):
|
||||||
|
# TODO: At the time of writing, those are not yet supported in transformers
|
||||||
|
# (https://github.com/huggingface/transformers/issues/27379). Alternatively, it is also possible to use
|
||||||
|
# facebookresearch/dinov2 directly (https://github.com/finegrain-ai/refiners/pull/132).
|
||||||
|
if flavor.endswith("_reg4"):
|
||||||
|
warn(f"DINOv2 with registers are not yet supported in Hugging Face, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
match flavor:
|
match flavor:
|
||||||
case "dinov2_vits14":
|
case "dinov2_vits14":
|
||||||
name = "dinov2-small"
|
name = "dinov2-small"
|
||||||
|
@ -89,3 +129,19 @@ def test_encoder(
|
||||||
our_features = our_backbone(x)
|
our_features = our_backbone(x)
|
||||||
|
|
||||||
assert (our_features - ref_features).abs().max() < 1e-3
|
assert (our_features - ref_features).abs().max() < 1e-3
|
||||||
|
|
||||||
|
|
||||||
|
# Mainly for DINOv2 + registers coverage (this test can be removed once `test_encoder` supports all flavors)
|
||||||
|
def test_encoder_only(
|
||||||
|
our_backbone: ViT,
|
||||||
|
seed_expected_norm: tuple[int, float],
|
||||||
|
test_device: torch.device,
|
||||||
|
):
|
||||||
|
seed, expected_norm = seed_expected_norm
|
||||||
|
manual_seed(seed)
|
||||||
|
|
||||||
|
x = torch.randn(1, 3, 518, 518).to(test_device)
|
||||||
|
|
||||||
|
our_features = our_backbone(x)
|
||||||
|
|
||||||
|
assert isclose(our_features.norm().item(), expected_norm, rel_tol=1e-04)
|
||||||
|
|
Loading…
Reference in a new issue