From 1a8ea9180fe87d982451ad07324b04db16e42d8e Mon Sep 17 00:00:00 2001
From: Laurent
Date: Fri, 29 Mar 2024 17:42:42 +0000
Subject: [PATCH] refactor dinov2 tests, check against official implementation

---
 .gitignore                                |   1 +
 CONTRIBUTING.md                           |   6 +
 scripts/prepare_test_weights.py           |  10 -
 tests/conftest.py                         |   6 +
 tests/foundationals/dinov2/test_dinov2.py | 211 +++++++++++-----------
 5 files changed, 118 insertions(+), 116 deletions(-)

diff --git a/.gitignore b/.gitignore
index 56a27d0..a8e9f40 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,6 +11,7 @@ venv/
 
 # tests' model weights
 tests/weights/
+tests/repos/
 
 # ruff
 .ruff_cache
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 95acdeb..d88899f 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -52,6 +52,12 @@ Then, download and convert all the necessary weights. Be aware that this will us
 python scripts/prepare_test_weights.py
 ```
 
+Some tests require cloning the original implementation of the model as they use `torch.hub.load`:
+
+```bash
+git clone git@github.com:facebookresearch/dinov2.git tests/repos/dinov2
+```
+
 Finally, run the tests:
 
 ```bash
diff --git a/scripts/prepare_test_weights.py b/scripts/prepare_test_weights.py
index 3dc96c2..eb759d6 100644
--- a/scripts/prepare_test_weights.py
+++ b/scripts/prepare_test_weights.py
@@ -388,16 +388,6 @@ def download_dinov2():
     ]
     download_files(urls, weights_folder)
 
-    # For testing (note: versions with registers are not available yet on HuggingFace)
-    for repo in ["dinov2-small", "dinov2-base", "dinov2-large"]:
-        base_folder = os.path.join(test_weights_dir, "facebook", repo)
-        urls = [
-            f"https://huggingface.co/facebook/{repo}/raw/main/config.json",
-            f"https://huggingface.co/facebook/{repo}/raw/main/preprocessor_config.json",
-            f"https://huggingface.co/facebook/{repo}/resolve/main/pytorch_model.bin",
-        ]
-        download_files(urls, base_folder)
-
 
 def download_lcm_base():
     base_folder = os.path.join(test_weights_dir, "latent-consistency/lcm-sdxl")
diff --git a/tests/conftest.py b/tests/conftest.py
index d1403ff..bcd3a69 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -21,6 +21,12 @@ def test_weights_path() -> Path:
     return Path(from_env) if from_env else PARENT_PATH / "weights"
 
 
+@fixture(scope="session")
+def test_repos_path() -> Path:
+    from_env = os.getenv("REFINERS_TEST_REPOS_DIR")
+    return Path(from_env) if from_env else PARENT_PATH / "repos"
+
+
 @fixture(scope="session")
 def test_e2e_path() -> Path:
     return PARENT_PATH / "e2e"
diff --git a/tests/foundationals/dinov2/test_dinov2.py b/tests/foundationals/dinov2/test_dinov2.py
index 7bcf818..9f68d15 100644
--- a/tests/foundationals/dinov2/test_dinov2.py
+++ b/tests/foundationals/dinov2/test_dinov2.py
@@ -1,14 +1,12 @@
-from math import isclose
 from pathlib import Path
+from typing import Any
 from warnings import warn
 
 import pytest
 import torch
-from transformers import AutoModel  # type: ignore
-from transformers.models.dinov2.modeling_dinov2 import Dinov2Model  # type: ignore
 
-from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
-from refiners.foundationals.dinov2 import (
+from refiners.fluxion.utils import load_from_safetensors, load_tensors, manual_seed, no_grad
+from refiners.foundationals.dinov2.dinov2 import (
     DINOv2_base,
     DINOv2_base_reg,
     DINOv2_large,
@@ -18,130 +16,131 @@ from refiners.foundationals.dinov2 import (
 )
 from refiners.foundationals.dinov2.vit import ViT
 
-FLAVORS = [
-    "dinov2_vits14",
-    "dinov2_vitb14",
-    "dinov2_vitl14",
-    "dinov2_vits14_reg4",
-    "dinov2_vitb14_reg4",
-    "dinov2_vitl14_reg4",
-]
+FLAVORS_MAP = {
+    "dinov2_vits14": DINOv2_small,
+    "dinov2_vits14_reg": DINOv2_small_reg,
+    "dinov2_vitb14": DINOv2_base,
+    "dinov2_vitb14_reg": DINOv2_base_reg,
+    "dinov2_vitl14": DINOv2_large,
+    "dinov2_vitl14_reg": DINOv2_large_reg,
+    # TODO: support giant flavors
+    # "dinov2_vitg14": DINOv2_giant,
+    # "dinov2_vitg14_reg": DINOv2_giant_reg,
+}
 
 
-@pytest.fixture(scope="module", params=FLAVORS)
+@pytest.fixture(scope="module", params=[224, 518])
+def resolution(request: pytest.FixtureRequest) -> int:
+    return request.param
+
+
+@pytest.fixture(scope="module", params=FLAVORS_MAP.keys())
 def flavor(request: pytest.FixtureRequest) -> str:
     return request.param
 
 
-# Temporary: see comments in `test_encoder_only`
 @pytest.fixture(scope="module")
-def seed_expected_norm(flavor: str) -> tuple[int, float]:
-    match flavor:
-        case "dinov2_vits14":
-            return (42, 1977.9213867)
-        case "dinov2_vitb14":
-            return (42, 1902.6384277)
-        case "dinov2_vitl14":
-            return (42, 1763.9187011)
-        case "dinov2_vits14_reg4":
-            return (42, 989.2380981)
-        case "dinov2_vitb14_reg4":
-            return (42, 974.4362182)
-        case "dinov2_vitl14_reg4":
-            return (42, 924.8797607)
-        case _:
-            raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
+def dinov2_repo_path(test_repos_path: Path) -> Path:
+    repo = test_repos_path / "dinov2"
+    if not repo.exists():
+        warn(f"could not find DINOv2 GitHub repo at {repo}, skipping")
+        pytest.skip(allow_module_level=True)
+    return repo
 
 
 @pytest.fixture(scope="module")
-def our_backbone(test_weights_path: Path, flavor: str, test_device: torch.device) -> ViT:
+def ref_model(
+    flavor: str,
+    dinov2_repo_path: Path,
+    test_weights_path: Path,
+    test_device: torch.device,
+) -> torch.nn.Module:
+    kwargs: dict[str, Any] = {}
+    if "reg" not in flavor:
+        kwargs["interpolate_offset"] = 0.0
+
+    model = torch.hub.load(  # type: ignore
+        model=flavor,
+        repo_or_dir=str(dinov2_repo_path),
+        source="local",
+        pretrained=False,  # to turn off automatic weights download (see load_state_dict below)
+        **kwargs,
+    ).to(device=test_device)
+
+    flavor = flavor.replace("_reg", "_reg4")
+    weights = test_weights_path / f"{flavor}_pretrain.pth"
+    if not weights.is_file():
+        warn(f"could not find weights at {weights}, skipping")
+        pytest.skip(allow_module_level=True)
+    model.load_state_dict(load_tensors(weights, device=test_device))
+
+    assert isinstance(model, torch.nn.Module)
+    return model
+
+
+@pytest.fixture(scope="module")
+def our_model(
+    test_weights_path: Path,
+    flavor: str,
+    test_device: torch.device,
+) -> ViT:
+    model = FLAVORS_MAP[flavor](device=test_device)
+
+    flavor = flavor.replace("_reg", "_reg4")
     weights = test_weights_path / f"{flavor}_pretrain.safetensors"
     if not weights.is_file():
         warn(f"could not find weights at {weights}, skipping")
         pytest.skip(allow_module_level=True)
-    match flavor:
-        case "dinov2_vits14":
-            backbone = DINOv2_small(device=test_device)
-        case "dinov2_vitb14":
-            backbone = DINOv2_base(device=test_device)
-        case "dinov2_vitl14":
-            backbone = DINOv2_large(device=test_device)
-        case "dinov2_vits14_reg4":
-            backbone = DINOv2_small_reg(device=test_device)
-        case "dinov2_vitb14_reg4":
-            backbone = DINOv2_base_reg(device=test_device)
-        case "dinov2_vitl14_reg4":
-            backbone = DINOv2_large_reg(device=test_device)
-        case _:
-            raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
+
     tensors = load_from_safetensors(weights)
-    backbone.load_state_dict(tensors)
-    return backbone
+    model.load_state_dict(tensors)
+
+    return model
 
 
-@pytest.fixture(scope="module")
-def dinov2_weights_path(test_weights_path: Path, flavor: str):
-    # TODO: At the time of writing, those are not yet supported in transformers
-    # (https://github.com/huggingface/transformers/issues/27379). Alternatively, it is also possible to use
-    # facebookresearch/dinov2 directly (https://github.com/finegrain-ai/refiners/pull/132).
-    if flavor.endswith("_reg4"):
-        warn(f"DINOv2 with registers are not yet supported in Hugging Face, skipping")
-        pytest.skip(allow_module_level=True)
-    match flavor:
-        case "dinov2_vits14":
-            name = "dinov2-small"
-        case "dinov2_vitb14":
-            name = "dinov2-base"
-        case "dinov2_vitl14":
-            name = "dinov2-large"
-        case _:
-            raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
-    r = test_weights_path / "facebook" / name
-    if not r.is_dir():
-        warn(f"could not find DINOv2 weights at {r}, skipping")
-        pytest.skip(allow_module_level=True)
-    return r
-
-
-@pytest.fixture(scope="module")
-def ref_backbone(dinov2_weights_path: Path, test_device: torch.device) -> Dinov2Model:
-    backbone = AutoModel.from_pretrained(dinov2_weights_path)  # type: ignore
-    assert isinstance(backbone, Dinov2Model)
-    return backbone.to(test_device)  # type: ignore
-
-
-def test_encoder(
-    ref_backbone: Dinov2Model,
-    our_backbone: ViT,
+@no_grad()
+def test_dinov2_facebook_weights(
+    ref_model: torch.nn.Module,
+    our_model: ViT,
+    resolution: int,
     test_device: torch.device,
-):
-    manual_seed(42)
+) -> None:
+    manual_seed(2)
+    input_data = torch.randn(
+        (1, 3, resolution, resolution),
+        device=test_device,
+    )
 
-    # Position encoding interpolation [1] at runtime is not supported yet. So stick to the default image resolution
-    # e.g. using (224, 224) pixels as input would give a runtime error (sequence size mismatch)
-    # [1]: https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/models/vision_transformer.py#L179
-    assert our_backbone.image_size == 518
+    ref_output = ref_model(input_data, is_training=True)
+    ref_cls = ref_output["x_norm_clstoken"]
+    ref_reg = ref_output["x_norm_regtokens"]
+    ref_patch = ref_output["x_norm_patchtokens"]
 
-    x = torch.randn(1, 3, 518, 518).to(test_device)
+    our_output = our_model(input_data)
+    our_cls = our_output[:, 0]
+    our_reg = our_output[:, 1 : our_model.num_registers + 1]
+    our_patch = our_output[:, our_model.num_registers + 1 :]
 
-    with no_grad():
-        ref_features = ref_backbone(x).last_hidden_state
-        our_features = our_backbone(x)
-
-    assert (our_features - ref_features).abs().max() < 1e-3
+    assert torch.allclose(ref_cls, our_cls, atol=1e-4)
+    assert torch.allclose(ref_reg, our_reg, atol=1e-4)
+    assert torch.allclose(ref_patch, our_patch, atol=3e-3)
 
 
-# Mainly for DINOv2 + registers coverage (this test can be removed once `test_encoder` supports all flavors)
-def test_encoder_only(
-    our_backbone: ViT,
-    seed_expected_norm: tuple[int, float],
+@no_grad()
+def test_dinov2_float16(
+    resolution: int,
     test_device: torch.device,
-):
-    seed, expected_norm = seed_expected_norm
-    manual_seed(seed)
+) -> None:
+    model = DINOv2_small(device=test_device, dtype=torch.float16)
 
-    x = torch.randn(1, 3, 518, 518).to(test_device)
+    manual_seed(2)
+    input_data = torch.randn(
+        (1, 3, resolution, resolution),
+        device=test_device,
+        dtype=torch.float16,
+    )
 
-    our_features = our_backbone(x)
-
-    assert isclose(our_features.norm().item(), expected_norm, rel_tol=1e-04)
+    output = model(input_data)
+    sequence_length = (resolution // model.patch_size) ** 2 + 1
+    assert output.shape == (1, sequence_length, model.embedding_dim)
+    assert output.dtype == torch.float16
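
Reviewer note: `test_dinov2_facebook_weights` relies on the refiners output sequence being laid out as [CLS token, register tokens (if any), patch tokens], matching the token dict the official model returns when called with `is_training=True`. Below is a minimal standalone sketch of the same comparison, assuming the `tests/repos/dinov2` clone and the weights produced by `scripts/prepare_test_weights.py` are in place; the paths and the flavor choice are illustrative, everything else mirrors the fixtures above.

```python
from pathlib import Path

import torch

from refiners.fluxion.utils import load_from_safetensors, load_tensors, manual_seed, no_grad
from refiners.foundationals.dinov2.dinov2 import DINOv2_small_reg

repo = Path("tests/repos/dinov2")  # local clone of facebookresearch/dinov2 (illustrative path)
weights = Path("tests/weights")    # output of scripts/prepare_test_weights.py (illustrative path)

# Reference model: the official implementation, loaded from the local clone
# (pretrained=False so torch.hub does not try to download the weights itself).
ref_model = torch.hub.load(
    repo_or_dir=str(repo),
    model="dinov2_vits14_reg",
    source="local",
    pretrained=False,
)
ref_model.load_state_dict(load_tensors(weights / "dinov2_vits14_reg4_pretrain.pth"))

# Same flavor in refiners, from the converted safetensors weights.
our_model = DINOv2_small_reg()
our_model.load_state_dict(load_from_safetensors(weights / "dinov2_vits14_reg4_pretrain.safetensors"))

manual_seed(2)
x = torch.randn(1, 3, 518, 518)

with no_grad():
    ref = ref_model(x, is_training=True)  # returns the dict of normalized tokens
    out = our_model(x)                    # returns the full token sequence

# Split our sequence the same way the reference model splits its outputs.
n = our_model.num_registers
assert torch.allclose(ref["x_norm_clstoken"], out[:, 0], atol=1e-4)
assert torch.allclose(ref["x_norm_regtokens"], out[:, 1 : n + 1], atol=1e-4)
assert torch.allclose(ref["x_norm_patchtokens"], out[:, n + 1 :], atol=3e-3)
```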