refactor dinov2 tests, check against official implementation

This commit is contained in:
Laurent 2024-03-29 17:42:42 +00:00 committed by Laureηt
parent 4f94dfb494
commit 1a8ea9180f
5 changed files with 118 additions and 116 deletions

1
.gitignore vendored
View file

@ -11,6 +11,7 @@ venv/
# tests' model weights
tests/weights/
tests/repos/
# ruff
.ruff_cache

View file

@ -52,6 +52,12 @@ Then, download and convert all the necessary weights. Be aware that this will us
python scripts/prepare_test_weights.py
```
Some tests require cloning the original implementation of the model as they use `torch.hub.load`:
```bash
git clone git@github.com:facebookresearch/dinov2.git tests/repos/dinov2
```
Finally, run the tests:
```bash

View file

@ -388,16 +388,6 @@ def download_dinov2():
]
download_files(urls, weights_folder)
# For testing (note: versions with registers are not available yet on HuggingFace)
for repo in ["dinov2-small", "dinov2-base", "dinov2-large"]:
base_folder = os.path.join(test_weights_dir, "facebook", repo)
urls = [
f"https://huggingface.co/facebook/{repo}/raw/main/config.json",
f"https://huggingface.co/facebook/{repo}/raw/main/preprocessor_config.json",
f"https://huggingface.co/facebook/{repo}/resolve/main/pytorch_model.bin",
]
download_files(urls, base_folder)
def download_lcm_base():
base_folder = os.path.join(test_weights_dir, "latent-consistency/lcm-sdxl")

View file

@ -21,6 +21,12 @@ def test_weights_path() -> Path:
return Path(from_env) if from_env else PARENT_PATH / "weights"
@fixture(scope="session")
def test_repos_path() -> Path:
from_env = os.getenv("REFINERS_TEST_REPOS_DIR")
return Path(from_env) if from_env else PARENT_PATH / "repos"
@fixture(scope="session")
def test_e2e_path() -> Path:
return PARENT_PATH / "e2e"

View file

@ -1,14 +1,12 @@
from math import isclose
from pathlib import Path
from typing import Any
from warnings import warn
import pytest
import torch
from transformers import AutoModel # type: ignore
from transformers.models.dinov2.modeling_dinov2 import Dinov2Model # type: ignore
from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
from refiners.foundationals.dinov2 import (
from refiners.fluxion.utils import load_from_safetensors, load_tensors, manual_seed, no_grad
from refiners.foundationals.dinov2.dinov2 import (
DINOv2_base,
DINOv2_base_reg,
DINOv2_large,
@ -18,130 +16,131 @@ from refiners.foundationals.dinov2 import (
)
from refiners.foundationals.dinov2.vit import ViT
FLAVORS = [
"dinov2_vits14",
"dinov2_vitb14",
"dinov2_vitl14",
"dinov2_vits14_reg4",
"dinov2_vitb14_reg4",
"dinov2_vitl14_reg4",
]
FLAVORS_MAP = {
"dinov2_vits14": DINOv2_small,
"dinov2_vits14_reg": DINOv2_small_reg,
"dinov2_vitb14": DINOv2_base,
"dinov2_vitb14_reg": DINOv2_base_reg,
"dinov2_vitl14": DINOv2_large,
"dinov2_vitl14_reg": DINOv2_large_reg,
# TODO: support giant flavors
# "dinov2_vitg14": DINOv2_giant,
# "dinov2_vitg14_reg": DINOv2_giant_reg,
}
@pytest.fixture(scope="module", params=FLAVORS)
@pytest.fixture(scope="module", params=[224, 518])
def resolution(request: pytest.FixtureRequest) -> int:
return request.param
@pytest.fixture(scope="module", params=FLAVORS_MAP.keys())
def flavor(request: pytest.FixtureRequest) -> str:
return request.param
# Temporary: see comments in `test_encoder_only`
@pytest.fixture(scope="module")
def seed_expected_norm(flavor: str) -> tuple[int, float]:
match flavor:
case "dinov2_vits14":
return (42, 1977.9213867)
case "dinov2_vitb14":
return (42, 1902.6384277)
case "dinov2_vitl14":
return (42, 1763.9187011)
case "dinov2_vits14_reg4":
return (42, 989.2380981)
case "dinov2_vitb14_reg4":
return (42, 974.4362182)
case "dinov2_vitl14_reg4":
return (42, 924.8797607)
case _:
raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
def dinov2_repo_path(test_repos_path: Path) -> Path:
repo = test_repos_path / "dinov2"
if not repo.exists():
warn(f"could not find DINOv2 GitHub repo at {repo}, skipping")
pytest.skip(allow_module_level=True)
return repo
@pytest.fixture(scope="module")
def our_backbone(test_weights_path: Path, flavor: str, test_device: torch.device) -> ViT:
def ref_model(
flavor: str,
dinov2_repo_path: Path,
test_weights_path: Path,
test_device: torch.device,
) -> torch.nn.Module:
kwargs: dict[str, Any] = {}
if "reg" not in flavor:
kwargs["interpolate_offset"] = 0.0
model = torch.hub.load( # type: ignore
model=flavor,
repo_or_dir=str(dinov2_repo_path),
source="local",
pretrained=False, # to turn off automatic weights download (see load_state_dict below)
**kwargs,
).to(device=test_device)
flavor = flavor.replace("_reg", "_reg4")
weights = test_weights_path / f"{flavor}_pretrain.pth"
if not weights.is_file():
warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True)
model.load_state_dict(load_tensors(weights, device=test_device))
assert isinstance(model, torch.nn.Module)
return model
@pytest.fixture(scope="module")
def our_model(
test_weights_path: Path,
flavor: str,
test_device: torch.device,
) -> ViT:
model = FLAVORS_MAP[flavor](device=test_device)
flavor = flavor.replace("_reg", "_reg4")
weights = test_weights_path / f"{flavor}_pretrain.safetensors"
if not weights.is_file():
warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True)
match flavor:
case "dinov2_vits14":
backbone = DINOv2_small(device=test_device)
case "dinov2_vitb14":
backbone = DINOv2_base(device=test_device)
case "dinov2_vitl14":
backbone = DINOv2_large(device=test_device)
case "dinov2_vits14_reg4":
backbone = DINOv2_small_reg(device=test_device)
case "dinov2_vitb14_reg4":
backbone = DINOv2_base_reg(device=test_device)
case "dinov2_vitl14_reg4":
backbone = DINOv2_large_reg(device=test_device)
case _:
raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
tensors = load_from_safetensors(weights)
backbone.load_state_dict(tensors)
return backbone
model.load_state_dict(tensors)
return model
@pytest.fixture(scope="module")
def dinov2_weights_path(test_weights_path: Path, flavor: str):
# TODO: At the time of writing, those are not yet supported in transformers
# (https://github.com/huggingface/transformers/issues/27379). Alternatively, it is also possible to use
# facebookresearch/dinov2 directly (https://github.com/finegrain-ai/refiners/pull/132).
if flavor.endswith("_reg4"):
warn(f"DINOv2 with registers are not yet supported in Hugging Face, skipping")
pytest.skip(allow_module_level=True)
match flavor:
case "dinov2_vits14":
name = "dinov2-small"
case "dinov2_vitb14":
name = "dinov2-base"
case "dinov2_vitl14":
name = "dinov2-large"
case _:
raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
r = test_weights_path / "facebook" / name
if not r.is_dir():
warn(f"could not find DINOv2 weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture(scope="module")
def ref_backbone(dinov2_weights_path: Path, test_device: torch.device) -> Dinov2Model:
backbone = AutoModel.from_pretrained(dinov2_weights_path) # type: ignore
assert isinstance(backbone, Dinov2Model)
return backbone.to(test_device) # type: ignore
def test_encoder(
ref_backbone: Dinov2Model,
our_backbone: ViT,
@no_grad()
def test_dinov2_facebook_weights(
ref_model: torch.nn.Module,
our_model: ViT,
resolution: int,
test_device: torch.device,
):
manual_seed(42)
) -> None:
manual_seed(2)
input_data = torch.randn(
(1, 3, resolution, resolution),
device=test_device,
)
# Position encoding interpolation [1] at runtime is not supported yet. So stick to the default image resolution
# e.g. using (224, 224) pixels as input would give a runtime error (sequence size mismatch)
# [1]: https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/models/vision_transformer.py#L179
assert our_backbone.image_size == 518
ref_output = ref_model(input_data, is_training=True)
ref_cls = ref_output["x_norm_clstoken"]
ref_reg = ref_output["x_norm_regtokens"]
ref_patch = ref_output["x_norm_patchtokens"]
x = torch.randn(1, 3, 518, 518).to(test_device)
our_output = our_model(input_data)
our_cls = our_output[:, 0]
our_reg = our_output[:, 1 : our_model.num_registers + 1]
our_patch = our_output[:, our_model.num_registers + 1 :]
with no_grad():
ref_features = ref_backbone(x).last_hidden_state
our_features = our_backbone(x)
assert (our_features - ref_features).abs().max() < 1e-3
assert torch.allclose(ref_cls, our_cls, atol=1e-4)
assert torch.allclose(ref_reg, our_reg, atol=1e-4)
assert torch.allclose(ref_patch, our_patch, atol=3e-3)
# Mainly for DINOv2 + registers coverage (this test can be removed once `test_encoder` supports all flavors)
def test_encoder_only(
our_backbone: ViT,
seed_expected_norm: tuple[int, float],
@no_grad()
def test_dinov2_float16(
resolution: int,
test_device: torch.device,
):
seed, expected_norm = seed_expected_norm
manual_seed(seed)
) -> None:
model = DINOv2_small(device=test_device, dtype=torch.float16)
x = torch.randn(1, 3, 518, 518).to(test_device)
manual_seed(2)
input_data = torch.randn(
(1, 3, resolution, resolution),
device=test_device,
dtype=torch.float16,
)
our_features = our_backbone(x)
assert isclose(our_features.norm().item(), expected_norm, rel_tol=1e-04)
output = model(input_data)
sequence_length = (resolution // model.patch_size) ** 2 + 1
assert output.shape == (1, sequence_length, model.embedding_dim)
assert output.dtype == torch.float16