mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 07:08:45 +00:00
refactor dinov2 tests, check against official implementation
This commit is contained in:
parent
4f94dfb494
commit
1a8ea9180f
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -11,6 +11,7 @@ venv/
|
|||
|
||||
# tests' model weights
|
||||
tests/weights/
|
||||
tests/repos/
|
||||
|
||||
# ruff
|
||||
.ruff_cache
|
||||
|
|
|
@ -52,6 +52,12 @@ Then, download and convert all the necessary weights. Be aware that this will us
|
|||
python scripts/prepare_test_weights.py
|
||||
```
|
||||
|
||||
Some tests require cloning the original implementation of the model as they use `torch.hub.load`:
|
||||
|
||||
```bash
|
||||
git clone git@github.com:facebookresearch/dinov2.git tests/repos/dinov2
|
||||
```
|
||||
|
||||
Finally, run the tests:
|
||||
|
||||
```bash
|
||||
|
|
|
@ -388,16 +388,6 @@ def download_dinov2():
|
|||
]
|
||||
download_files(urls, weights_folder)
|
||||
|
||||
# For testing (note: versions with registers are not available yet on HuggingFace)
|
||||
for repo in ["dinov2-small", "dinov2-base", "dinov2-large"]:
|
||||
base_folder = os.path.join(test_weights_dir, "facebook", repo)
|
||||
urls = [
|
||||
f"https://huggingface.co/facebook/{repo}/raw/main/config.json",
|
||||
f"https://huggingface.co/facebook/{repo}/raw/main/preprocessor_config.json",
|
||||
f"https://huggingface.co/facebook/{repo}/resolve/main/pytorch_model.bin",
|
||||
]
|
||||
download_files(urls, base_folder)
|
||||
|
||||
|
||||
def download_lcm_base():
|
||||
base_folder = os.path.join(test_weights_dir, "latent-consistency/lcm-sdxl")
|
||||
|
|
|
@ -21,6 +21,12 @@ def test_weights_path() -> Path:
|
|||
return Path(from_env) if from_env else PARENT_PATH / "weights"
|
||||
|
||||
|
||||
@fixture(scope="session")
|
||||
def test_repos_path() -> Path:
|
||||
from_env = os.getenv("REFINERS_TEST_REPOS_DIR")
|
||||
return Path(from_env) if from_env else PARENT_PATH / "repos"
|
||||
|
||||
|
||||
@fixture(scope="session")
|
||||
def test_e2e_path() -> Path:
|
||||
return PARENT_PATH / "e2e"
|
||||
|
|
|
@ -1,14 +1,12 @@
|
|||
from math import isclose
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from warnings import warn
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoModel # type: ignore
|
||||
from transformers.models.dinov2.modeling_dinov2 import Dinov2Model # type: ignore
|
||||
|
||||
from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
|
||||
from refiners.foundationals.dinov2 import (
|
||||
from refiners.fluxion.utils import load_from_safetensors, load_tensors, manual_seed, no_grad
|
||||
from refiners.foundationals.dinov2.dinov2 import (
|
||||
DINOv2_base,
|
||||
DINOv2_base_reg,
|
||||
DINOv2_large,
|
||||
|
@ -18,130 +16,131 @@ from refiners.foundationals.dinov2 import (
|
|||
)
|
||||
from refiners.foundationals.dinov2.vit import ViT
|
||||
|
||||
FLAVORS = [
|
||||
"dinov2_vits14",
|
||||
"dinov2_vitb14",
|
||||
"dinov2_vitl14",
|
||||
"dinov2_vits14_reg4",
|
||||
"dinov2_vitb14_reg4",
|
||||
"dinov2_vitl14_reg4",
|
||||
]
|
||||
FLAVORS_MAP = {
|
||||
"dinov2_vits14": DINOv2_small,
|
||||
"dinov2_vits14_reg": DINOv2_small_reg,
|
||||
"dinov2_vitb14": DINOv2_base,
|
||||
"dinov2_vitb14_reg": DINOv2_base_reg,
|
||||
"dinov2_vitl14": DINOv2_large,
|
||||
"dinov2_vitl14_reg": DINOv2_large_reg,
|
||||
# TODO: support giant flavors
|
||||
# "dinov2_vitg14": DINOv2_giant,
|
||||
# "dinov2_vitg14_reg": DINOv2_giant_reg,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=FLAVORS)
|
||||
@pytest.fixture(scope="module", params=[224, 518])
|
||||
def resolution(request: pytest.FixtureRequest) -> int:
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=FLAVORS_MAP.keys())
|
||||
def flavor(request: pytest.FixtureRequest) -> str:
|
||||
return request.param
|
||||
|
||||
|
||||
# Temporary: see comments in `test_encoder_only`
|
||||
@pytest.fixture(scope="module")
|
||||
def seed_expected_norm(flavor: str) -> tuple[int, float]:
|
||||
match flavor:
|
||||
case "dinov2_vits14":
|
||||
return (42, 1977.9213867)
|
||||
case "dinov2_vitb14":
|
||||
return (42, 1902.6384277)
|
||||
case "dinov2_vitl14":
|
||||
return (42, 1763.9187011)
|
||||
case "dinov2_vits14_reg4":
|
||||
return (42, 989.2380981)
|
||||
case "dinov2_vitb14_reg4":
|
||||
return (42, 974.4362182)
|
||||
case "dinov2_vitl14_reg4":
|
||||
return (42, 924.8797607)
|
||||
case _:
|
||||
raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
|
||||
def dinov2_repo_path(test_repos_path: Path) -> Path:
|
||||
repo = test_repos_path / "dinov2"
|
||||
if not repo.exists():
|
||||
warn(f"could not find DINOv2 GitHub repo at {repo}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
return repo
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def our_backbone(test_weights_path: Path, flavor: str, test_device: torch.device) -> ViT:
|
||||
def ref_model(
|
||||
flavor: str,
|
||||
dinov2_repo_path: Path,
|
||||
test_weights_path: Path,
|
||||
test_device: torch.device,
|
||||
) -> torch.nn.Module:
|
||||
kwargs: dict[str, Any] = {}
|
||||
if "reg" not in flavor:
|
||||
kwargs["interpolate_offset"] = 0.0
|
||||
|
||||
model = torch.hub.load( # type: ignore
|
||||
model=flavor,
|
||||
repo_or_dir=str(dinov2_repo_path),
|
||||
source="local",
|
||||
pretrained=False, # to turn off automatic weights download (see load_state_dict below)
|
||||
**kwargs,
|
||||
).to(device=test_device)
|
||||
|
||||
flavor = flavor.replace("_reg", "_reg4")
|
||||
weights = test_weights_path / f"{flavor}_pretrain.pth"
|
||||
if not weights.is_file():
|
||||
warn(f"could not find weights at {weights}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
model.load_state_dict(load_tensors(weights, device=test_device))
|
||||
|
||||
assert isinstance(model, torch.nn.Module)
|
||||
return model
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def our_model(
|
||||
test_weights_path: Path,
|
||||
flavor: str,
|
||||
test_device: torch.device,
|
||||
) -> ViT:
|
||||
model = FLAVORS_MAP[flavor](device=test_device)
|
||||
|
||||
flavor = flavor.replace("_reg", "_reg4")
|
||||
weights = test_weights_path / f"{flavor}_pretrain.safetensors"
|
||||
if not weights.is_file():
|
||||
warn(f"could not find weights at {weights}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
match flavor:
|
||||
case "dinov2_vits14":
|
||||
backbone = DINOv2_small(device=test_device)
|
||||
case "dinov2_vitb14":
|
||||
backbone = DINOv2_base(device=test_device)
|
||||
case "dinov2_vitl14":
|
||||
backbone = DINOv2_large(device=test_device)
|
||||
case "dinov2_vits14_reg4":
|
||||
backbone = DINOv2_small_reg(device=test_device)
|
||||
case "dinov2_vitb14_reg4":
|
||||
backbone = DINOv2_base_reg(device=test_device)
|
||||
case "dinov2_vitl14_reg4":
|
||||
backbone = DINOv2_large_reg(device=test_device)
|
||||
case _:
|
||||
raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
|
||||
|
||||
tensors = load_from_safetensors(weights)
|
||||
backbone.load_state_dict(tensors)
|
||||
return backbone
|
||||
model.load_state_dict(tensors)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def dinov2_weights_path(test_weights_path: Path, flavor: str):
|
||||
# TODO: At the time of writing, those are not yet supported in transformers
|
||||
# (https://github.com/huggingface/transformers/issues/27379). Alternatively, it is also possible to use
|
||||
# facebookresearch/dinov2 directly (https://github.com/finegrain-ai/refiners/pull/132).
|
||||
if flavor.endswith("_reg4"):
|
||||
warn(f"DINOv2 with registers are not yet supported in Hugging Face, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
match flavor:
|
||||
case "dinov2_vits14":
|
||||
name = "dinov2-small"
|
||||
case "dinov2_vitb14":
|
||||
name = "dinov2-base"
|
||||
case "dinov2_vitl14":
|
||||
name = "dinov2-large"
|
||||
case _:
|
||||
raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
|
||||
r = test_weights_path / "facebook" / name
|
||||
if not r.is_dir():
|
||||
warn(f"could not find DINOv2 weights at {r}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
return r
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ref_backbone(dinov2_weights_path: Path, test_device: torch.device) -> Dinov2Model:
|
||||
backbone = AutoModel.from_pretrained(dinov2_weights_path) # type: ignore
|
||||
assert isinstance(backbone, Dinov2Model)
|
||||
return backbone.to(test_device) # type: ignore
|
||||
|
||||
|
||||
def test_encoder(
|
||||
ref_backbone: Dinov2Model,
|
||||
our_backbone: ViT,
|
||||
@no_grad()
|
||||
def test_dinov2_facebook_weights(
|
||||
ref_model: torch.nn.Module,
|
||||
our_model: ViT,
|
||||
resolution: int,
|
||||
test_device: torch.device,
|
||||
):
|
||||
manual_seed(42)
|
||||
) -> None:
|
||||
manual_seed(2)
|
||||
input_data = torch.randn(
|
||||
(1, 3, resolution, resolution),
|
||||
device=test_device,
|
||||
)
|
||||
|
||||
# Position encoding interpolation [1] at runtime is not supported yet. So stick to the default image resolution
|
||||
# e.g. using (224, 224) pixels as input would give a runtime error (sequence size mismatch)
|
||||
# [1]: https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/models/vision_transformer.py#L179
|
||||
assert our_backbone.image_size == 518
|
||||
ref_output = ref_model(input_data, is_training=True)
|
||||
ref_cls = ref_output["x_norm_clstoken"]
|
||||
ref_reg = ref_output["x_norm_regtokens"]
|
||||
ref_patch = ref_output["x_norm_patchtokens"]
|
||||
|
||||
x = torch.randn(1, 3, 518, 518).to(test_device)
|
||||
our_output = our_model(input_data)
|
||||
our_cls = our_output[:, 0]
|
||||
our_reg = our_output[:, 1 : our_model.num_registers + 1]
|
||||
our_patch = our_output[:, our_model.num_registers + 1 :]
|
||||
|
||||
with no_grad():
|
||||
ref_features = ref_backbone(x).last_hidden_state
|
||||
our_features = our_backbone(x)
|
||||
|
||||
assert (our_features - ref_features).abs().max() < 1e-3
|
||||
assert torch.allclose(ref_cls, our_cls, atol=1e-4)
|
||||
assert torch.allclose(ref_reg, our_reg, atol=1e-4)
|
||||
assert torch.allclose(ref_patch, our_patch, atol=3e-3)
|
||||
|
||||
|
||||
# Mainly for DINOv2 + registers coverage (this test can be removed once `test_encoder` supports all flavors)
|
||||
def test_encoder_only(
|
||||
our_backbone: ViT,
|
||||
seed_expected_norm: tuple[int, float],
|
||||
@no_grad()
|
||||
def test_dinov2_float16(
|
||||
resolution: int,
|
||||
test_device: torch.device,
|
||||
):
|
||||
seed, expected_norm = seed_expected_norm
|
||||
manual_seed(seed)
|
||||
) -> None:
|
||||
model = DINOv2_small(device=test_device, dtype=torch.float16)
|
||||
|
||||
x = torch.randn(1, 3, 518, 518).to(test_device)
|
||||
manual_seed(2)
|
||||
input_data = torch.randn(
|
||||
(1, 3, resolution, resolution),
|
||||
device=test_device,
|
||||
dtype=torch.float16,
|
||||
)
|
||||
|
||||
our_features = our_backbone(x)
|
||||
|
||||
assert isclose(our_features.norm().item(), expected_norm, rel_tol=1e-04)
|
||||
output = model(input_data)
|
||||
sequence_length = (resolution // model.patch_size) ** 2 + 1
|
||||
assert output.shape == (1, sequence_length, model.embedding_dim)
|
||||
assert output.dtype == torch.float16
|
||||
|
|
Loading…
Reference in a new issue