refactor dinov2 tests, check against official implementation

This commit is contained in:
Laurent 2024-03-29 17:42:42 +00:00 committed by Laureηt
parent 4f94dfb494
commit 1a8ea9180f
5 changed files with 118 additions and 116 deletions

1
.gitignore vendored
View file

@ -11,6 +11,7 @@ venv/
# tests' model weights # tests' model weights
tests/weights/ tests/weights/
tests/repos/
# ruff # ruff
.ruff_cache .ruff_cache

View file

@ -52,6 +52,12 @@ Then, download and convert all the necessary weights. Be aware that this will us
python scripts/prepare_test_weights.py python scripts/prepare_test_weights.py
``` ```
Some tests compare against the original implementation of the model, which they load from a local clone with `torch.hub.load`, so clone it first:
```bash
git clone git@github.com:facebookresearch/dinov2.git tests/repos/dinov2
```
Finally, run the tests: Finally, run the tests:
```bash ```bash

View file

@ -388,16 +388,6 @@ def download_dinov2():
] ]
download_files(urls, weights_folder) download_files(urls, weights_folder)
# For testing (note: versions with registers are not available yet on HuggingFace)
for repo in ["dinov2-small", "dinov2-base", "dinov2-large"]:
base_folder = os.path.join(test_weights_dir, "facebook", repo)
urls = [
f"https://huggingface.co/facebook/{repo}/raw/main/config.json",
f"https://huggingface.co/facebook/{repo}/raw/main/preprocessor_config.json",
f"https://huggingface.co/facebook/{repo}/resolve/main/pytorch_model.bin",
]
download_files(urls, base_folder)
def download_lcm_base(): def download_lcm_base():
base_folder = os.path.join(test_weights_dir, "latent-consistency/lcm-sdxl") base_folder = os.path.join(test_weights_dir, "latent-consistency/lcm-sdxl")

View file

@ -21,6 +21,12 @@ def test_weights_path() -> Path:
return Path(from_env) if from_env else PARENT_PATH / "weights" return Path(from_env) if from_env else PARENT_PATH / "weights"
@fixture(scope="session")
def test_repos_path() -> Path:
    """Return the root directory under which tests look for cloned repositories.

    A non-empty `REFINERS_TEST_REPOS_DIR` environment variable takes
    precedence; otherwise the `repos` folder under `PARENT_PATH` is used.
    """
    override = os.getenv("REFINERS_TEST_REPOS_DIR")
    # An unset or empty variable falls through to the default location.
    if override:
        return Path(override)
    return PARENT_PATH / "repos"
@fixture(scope="session") @fixture(scope="session")
def test_e2e_path() -> Path: def test_e2e_path() -> Path:
return PARENT_PATH / "e2e" return PARENT_PATH / "e2e"

View file

@ -1,14 +1,12 @@
from math import isclose
from pathlib import Path from pathlib import Path
from typing import Any
from warnings import warn from warnings import warn
import pytest import pytest
import torch import torch
from transformers import AutoModel # type: ignore
from transformers.models.dinov2.modeling_dinov2 import Dinov2Model # type: ignore
from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad from refiners.fluxion.utils import load_from_safetensors, load_tensors, manual_seed, no_grad
from refiners.foundationals.dinov2 import ( from refiners.foundationals.dinov2.dinov2 import (
DINOv2_base, DINOv2_base,
DINOv2_base_reg, DINOv2_base_reg,
DINOv2_large, DINOv2_large,
@ -18,130 +16,131 @@ from refiners.foundationals.dinov2 import (
) )
from refiners.foundationals.dinov2.vit import ViT from refiners.foundationals.dinov2.vit import ViT
FLAVORS = [ FLAVORS_MAP = {
"dinov2_vits14", "dinov2_vits14": DINOv2_small,
"dinov2_vitb14", "dinov2_vits14_reg": DINOv2_small_reg,
"dinov2_vitl14", "dinov2_vitb14": DINOv2_base,
"dinov2_vits14_reg4", "dinov2_vitb14_reg": DINOv2_base_reg,
"dinov2_vitb14_reg4", "dinov2_vitl14": DINOv2_large,
"dinov2_vitl14_reg4", "dinov2_vitl14_reg": DINOv2_large_reg,
] # TODO: support giant flavors
# "dinov2_vitg14": DINOv2_giant,
# "dinov2_vitg14_reg": DINOv2_giant_reg,
}
@pytest.fixture(scope="module", params=FLAVORS) @pytest.fixture(scope="module", params=[224, 518])
def resolution(request: pytest.FixtureRequest) -> int:
return request.param
@pytest.fixture(scope="module", params=FLAVORS_MAP.keys())
def flavor(request: pytest.FixtureRequest) -> str: def flavor(request: pytest.FixtureRequest) -> str:
return request.param return request.param
# Temporary: see comments in `test_encoder_only`
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def seed_expected_norm(flavor: str) -> tuple[int, float]: def dinov2_repo_path(test_repos_path: Path) -> Path:
match flavor: repo = test_repos_path / "dinov2"
case "dinov2_vits14": if not repo.exists():
return (42, 1977.9213867) warn(f"could not find DINOv2 GitHub repo at {repo}, skipping")
case "dinov2_vitb14": pytest.skip(allow_module_level=True)
return (42, 1902.6384277) return repo
case "dinov2_vitl14":
return (42, 1763.9187011)
case "dinov2_vits14_reg4":
return (42, 989.2380981)
case "dinov2_vitb14_reg4":
return (42, 974.4362182)
case "dinov2_vitl14_reg4":
return (42, 924.8797607)
case _:
raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def our_backbone(test_weights_path: Path, flavor: str, test_device: torch.device) -> ViT: def ref_model(
flavor: str,
dinov2_repo_path: Path,
test_weights_path: Path,
test_device: torch.device,
) -> torch.nn.Module:
kwargs: dict[str, Any] = {}
if "reg" not in flavor:
kwargs["interpolate_offset"] = 0.0
model = torch.hub.load( # type: ignore
model=flavor,
repo_or_dir=str(dinov2_repo_path),
source="local",
pretrained=False, # to turn off automatic weights download (see load_state_dict below)
**kwargs,
).to(device=test_device)
flavor = flavor.replace("_reg", "_reg4")
weights = test_weights_path / f"{flavor}_pretrain.pth"
if not weights.is_file():
warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True)
model.load_state_dict(load_tensors(weights, device=test_device))
assert isinstance(model, torch.nn.Module)
return model
@pytest.fixture(scope="module")
def our_model(
test_weights_path: Path,
flavor: str,
test_device: torch.device,
) -> ViT:
model = FLAVORS_MAP[flavor](device=test_device)
flavor = flavor.replace("_reg", "_reg4")
weights = test_weights_path / f"{flavor}_pretrain.safetensors" weights = test_weights_path / f"{flavor}_pretrain.safetensors"
if not weights.is_file(): if not weights.is_file():
warn(f"could not find weights at {weights}, skipping") warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True) pytest.skip(allow_module_level=True)
match flavor:
case "dinov2_vits14":
backbone = DINOv2_small(device=test_device)
case "dinov2_vitb14":
backbone = DINOv2_base(device=test_device)
case "dinov2_vitl14":
backbone = DINOv2_large(device=test_device)
case "dinov2_vits14_reg4":
backbone = DINOv2_small_reg(device=test_device)
case "dinov2_vitb14_reg4":
backbone = DINOv2_base_reg(device=test_device)
case "dinov2_vitl14_reg4":
backbone = DINOv2_large_reg(device=test_device)
case _:
raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
tensors = load_from_safetensors(weights) tensors = load_from_safetensors(weights)
backbone.load_state_dict(tensors) model.load_state_dict(tensors)
return backbone
return model
@pytest.fixture(scope="module") @no_grad()
def dinov2_weights_path(test_weights_path: Path, flavor: str): def test_dinov2_facebook_weights(
# TODO: At the time of writing, those are not yet supported in transformers ref_model: torch.nn.Module,
# (https://github.com/huggingface/transformers/issues/27379). Alternatively, it is also possible to use our_model: ViT,
# facebookresearch/dinov2 directly (https://github.com/finegrain-ai/refiners/pull/132). resolution: int,
if flavor.endswith("_reg4"):
warn(f"DINOv2 with registers are not yet supported in Hugging Face, skipping")
pytest.skip(allow_module_level=True)
match flavor:
case "dinov2_vits14":
name = "dinov2-small"
case "dinov2_vitb14":
name = "dinov2-base"
case "dinov2_vitl14":
name = "dinov2-large"
case _:
raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
r = test_weights_path / "facebook" / name
if not r.is_dir():
warn(f"could not find DINOv2 weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
@pytest.fixture(scope="module")
def ref_backbone(dinov2_weights_path: Path, test_device: torch.device) -> Dinov2Model:
backbone = AutoModel.from_pretrained(dinov2_weights_path) # type: ignore
assert isinstance(backbone, Dinov2Model)
return backbone.to(test_device) # type: ignore
def test_encoder(
ref_backbone: Dinov2Model,
our_backbone: ViT,
test_device: torch.device, test_device: torch.device,
): ) -> None:
manual_seed(42) manual_seed(2)
input_data = torch.randn(
(1, 3, resolution, resolution),
device=test_device,
)
# Position encoding interpolation [1] at runtime is not supported yet. So stick to the default image resolution ref_output = ref_model(input_data, is_training=True)
# e.g. using (224, 224) pixels as input would give a runtime error (sequence size mismatch) ref_cls = ref_output["x_norm_clstoken"]
# [1]: https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/models/vision_transformer.py#L179 ref_reg = ref_output["x_norm_regtokens"]
assert our_backbone.image_size == 518 ref_patch = ref_output["x_norm_patchtokens"]
x = torch.randn(1, 3, 518, 518).to(test_device) our_output = our_model(input_data)
our_cls = our_output[:, 0]
our_reg = our_output[:, 1 : our_model.num_registers + 1]
our_patch = our_output[:, our_model.num_registers + 1 :]
with no_grad(): assert torch.allclose(ref_cls, our_cls, atol=1e-4)
ref_features = ref_backbone(x).last_hidden_state assert torch.allclose(ref_reg, our_reg, atol=1e-4)
our_features = our_backbone(x) assert torch.allclose(ref_patch, our_patch, atol=3e-3)
assert (our_features - ref_features).abs().max() < 1e-3
# Mainly for DINOv2 + registers coverage (this test can be removed once `test_encoder` supports all flavors) @no_grad()
def test_encoder_only( def test_dinov2_float16(
our_backbone: ViT, resolution: int,
seed_expected_norm: tuple[int, float],
test_device: torch.device, test_device: torch.device,
): ) -> None:
seed, expected_norm = seed_expected_norm model = DINOv2_small(device=test_device, dtype=torch.float16)
manual_seed(seed)
x = torch.randn(1, 3, 518, 518).to(test_device) manual_seed(2)
input_data = torch.randn(
(1, 3, resolution, resolution),
device=test_device,
dtype=torch.float16,
)
our_features = our_backbone(x) output = model(input_data)
sequence_length = (resolution // model.patch_size) ** 2 + 1
assert isclose(our_features.norm().item(), expected_norm, rel_tol=1e-04) assert output.shape == (1, sequence_length, model.embedding_dim)
assert output.dtype == torch.float16