mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 23:28:45 +00:00
refactor dinov2 tests, check against official implementation
This commit is contained in:
parent
4f94dfb494
commit
1a8ea9180f
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -11,6 +11,7 @@ venv/
|
||||||
|
|
||||||
# tests' model weights
|
# tests' model weights
|
||||||
tests/weights/
|
tests/weights/
|
||||||
|
tests/repos/
|
||||||
|
|
||||||
# ruff
|
# ruff
|
||||||
.ruff_cache
|
.ruff_cache
|
||||||
|
|
|
@ -52,6 +52,12 @@ Then, download and convert all the necessary weights. Be aware that this will us
|
||||||
python scripts/prepare_test_weights.py
|
python scripts/prepare_test_weights.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Some tests require cloning the original implementation of the model as they use `torch.hub.load`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone git@github.com:facebookresearch/dinov2.git tests/repos/dinov2
|
||||||
|
```
|
||||||
|
|
||||||
Finally, run the tests:
|
Finally, run the tests:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|
|
@ -388,16 +388,6 @@ def download_dinov2():
|
||||||
]
|
]
|
||||||
download_files(urls, weights_folder)
|
download_files(urls, weights_folder)
|
||||||
|
|
||||||
# For testing (note: versions with registers are not available yet on HuggingFace)
|
|
||||||
for repo in ["dinov2-small", "dinov2-base", "dinov2-large"]:
|
|
||||||
base_folder = os.path.join(test_weights_dir, "facebook", repo)
|
|
||||||
urls = [
|
|
||||||
f"https://huggingface.co/facebook/{repo}/raw/main/config.json",
|
|
||||||
f"https://huggingface.co/facebook/{repo}/raw/main/preprocessor_config.json",
|
|
||||||
f"https://huggingface.co/facebook/{repo}/resolve/main/pytorch_model.bin",
|
|
||||||
]
|
|
||||||
download_files(urls, base_folder)
|
|
||||||
|
|
||||||
|
|
||||||
def download_lcm_base():
|
def download_lcm_base():
|
||||||
base_folder = os.path.join(test_weights_dir, "latent-consistency/lcm-sdxl")
|
base_folder = os.path.join(test_weights_dir, "latent-consistency/lcm-sdxl")
|
||||||
|
|
|
@ -21,6 +21,12 @@ def test_weights_path() -> Path:
|
||||||
return Path(from_env) if from_env else PARENT_PATH / "weights"
|
return Path(from_env) if from_env else PARENT_PATH / "weights"
|
||||||
|
|
||||||
|
|
||||||
|
@fixture(scope="session")
|
||||||
|
def test_repos_path() -> Path:
|
||||||
|
from_env = os.getenv("REFINERS_TEST_REPOS_DIR")
|
||||||
|
return Path(from_env) if from_env else PARENT_PATH / "repos"
|
||||||
|
|
||||||
|
|
||||||
@fixture(scope="session")
|
@fixture(scope="session")
|
||||||
def test_e2e_path() -> Path:
|
def test_e2e_path() -> Path:
|
||||||
return PARENT_PATH / "e2e"
|
return PARENT_PATH / "e2e"
|
||||||
|
|
|
@ -1,14 +1,12 @@
|
||||||
from math import isclose
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoModel # type: ignore
|
|
||||||
from transformers.models.dinov2.modeling_dinov2 import Dinov2Model # type: ignore
|
|
||||||
|
|
||||||
from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
|
from refiners.fluxion.utils import load_from_safetensors, load_tensors, manual_seed, no_grad
|
||||||
from refiners.foundationals.dinov2 import (
|
from refiners.foundationals.dinov2.dinov2 import (
|
||||||
DINOv2_base,
|
DINOv2_base,
|
||||||
DINOv2_base_reg,
|
DINOv2_base_reg,
|
||||||
DINOv2_large,
|
DINOv2_large,
|
||||||
|
@ -18,130 +16,131 @@ from refiners.foundationals.dinov2 import (
|
||||||
)
|
)
|
||||||
from refiners.foundationals.dinov2.vit import ViT
|
from refiners.foundationals.dinov2.vit import ViT
|
||||||
|
|
||||||
FLAVORS = [
|
FLAVORS_MAP = {
|
||||||
"dinov2_vits14",
|
"dinov2_vits14": DINOv2_small,
|
||||||
"dinov2_vitb14",
|
"dinov2_vits14_reg": DINOv2_small_reg,
|
||||||
"dinov2_vitl14",
|
"dinov2_vitb14": DINOv2_base,
|
||||||
"dinov2_vits14_reg4",
|
"dinov2_vitb14_reg": DINOv2_base_reg,
|
||||||
"dinov2_vitb14_reg4",
|
"dinov2_vitl14": DINOv2_large,
|
||||||
"dinov2_vitl14_reg4",
|
"dinov2_vitl14_reg": DINOv2_large_reg,
|
||||||
]
|
# TODO: support giant flavors
|
||||||
|
# "dinov2_vitg14": DINOv2_giant,
|
||||||
|
# "dinov2_vitg14_reg": DINOv2_giant_reg,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module", params=FLAVORS)
|
@pytest.fixture(scope="module", params=[224, 518])
|
||||||
|
def resolution(request: pytest.FixtureRequest) -> int:
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", params=FLAVORS_MAP.keys())
|
||||||
def flavor(request: pytest.FixtureRequest) -> str:
|
def flavor(request: pytest.FixtureRequest) -> str:
|
||||||
return request.param
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
# Temporary: see comments in `test_encoder_only`
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def seed_expected_norm(flavor: str) -> tuple[int, float]:
|
def dinov2_repo_path(test_repos_path: Path) -> Path:
|
||||||
match flavor:
|
repo = test_repos_path / "dinov2"
|
||||||
case "dinov2_vits14":
|
if not repo.exists():
|
||||||
return (42, 1977.9213867)
|
warn(f"could not find DINOv2 GitHub repo at {repo}, skipping")
|
||||||
case "dinov2_vitb14":
|
pytest.skip(allow_module_level=True)
|
||||||
return (42, 1902.6384277)
|
return repo
|
||||||
case "dinov2_vitl14":
|
|
||||||
return (42, 1763.9187011)
|
|
||||||
case "dinov2_vits14_reg4":
|
|
||||||
return (42, 989.2380981)
|
|
||||||
case "dinov2_vitb14_reg4":
|
|
||||||
return (42, 974.4362182)
|
|
||||||
case "dinov2_vitl14_reg4":
|
|
||||||
return (42, 924.8797607)
|
|
||||||
case _:
|
|
||||||
raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def our_backbone(test_weights_path: Path, flavor: str, test_device: torch.device) -> ViT:
|
def ref_model(
|
||||||
|
flavor: str,
|
||||||
|
dinov2_repo_path: Path,
|
||||||
|
test_weights_path: Path,
|
||||||
|
test_device: torch.device,
|
||||||
|
) -> torch.nn.Module:
|
||||||
|
kwargs: dict[str, Any] = {}
|
||||||
|
if "reg" not in flavor:
|
||||||
|
kwargs["interpolate_offset"] = 0.0
|
||||||
|
|
||||||
|
model = torch.hub.load( # type: ignore
|
||||||
|
model=flavor,
|
||||||
|
repo_or_dir=str(dinov2_repo_path),
|
||||||
|
source="local",
|
||||||
|
pretrained=False, # to turn off automatic weights download (see load_state_dict below)
|
||||||
|
**kwargs,
|
||||||
|
).to(device=test_device)
|
||||||
|
|
||||||
|
flavor = flavor.replace("_reg", "_reg4")
|
||||||
|
weights = test_weights_path / f"{flavor}_pretrain.pth"
|
||||||
|
if not weights.is_file():
|
||||||
|
warn(f"could not find weights at {weights}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
model.load_state_dict(load_tensors(weights, device=test_device))
|
||||||
|
|
||||||
|
assert isinstance(model, torch.nn.Module)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def our_model(
|
||||||
|
test_weights_path: Path,
|
||||||
|
flavor: str,
|
||||||
|
test_device: torch.device,
|
||||||
|
) -> ViT:
|
||||||
|
model = FLAVORS_MAP[flavor](device=test_device)
|
||||||
|
|
||||||
|
flavor = flavor.replace("_reg", "_reg4")
|
||||||
weights = test_weights_path / f"{flavor}_pretrain.safetensors"
|
weights = test_weights_path / f"{flavor}_pretrain.safetensors"
|
||||||
if not weights.is_file():
|
if not weights.is_file():
|
||||||
warn(f"could not find weights at {weights}, skipping")
|
warn(f"could not find weights at {weights}, skipping")
|
||||||
pytest.skip(allow_module_level=True)
|
pytest.skip(allow_module_level=True)
|
||||||
match flavor:
|
|
||||||
case "dinov2_vits14":
|
|
||||||
backbone = DINOv2_small(device=test_device)
|
|
||||||
case "dinov2_vitb14":
|
|
||||||
backbone = DINOv2_base(device=test_device)
|
|
||||||
case "dinov2_vitl14":
|
|
||||||
backbone = DINOv2_large(device=test_device)
|
|
||||||
case "dinov2_vits14_reg4":
|
|
||||||
backbone = DINOv2_small_reg(device=test_device)
|
|
||||||
case "dinov2_vitb14_reg4":
|
|
||||||
backbone = DINOv2_base_reg(device=test_device)
|
|
||||||
case "dinov2_vitl14_reg4":
|
|
||||||
backbone = DINOv2_large_reg(device=test_device)
|
|
||||||
case _:
|
|
||||||
raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
|
|
||||||
tensors = load_from_safetensors(weights)
|
tensors = load_from_safetensors(weights)
|
||||||
backbone.load_state_dict(tensors)
|
model.load_state_dict(tensors)
|
||||||
return backbone
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@no_grad()
|
||||||
def dinov2_weights_path(test_weights_path: Path, flavor: str):
|
def test_dinov2_facebook_weights(
|
||||||
# TODO: At the time of writing, those are not yet supported in transformers
|
ref_model: torch.nn.Module,
|
||||||
# (https://github.com/huggingface/transformers/issues/27379). Alternatively, it is also possible to use
|
our_model: ViT,
|
||||||
# facebookresearch/dinov2 directly (https://github.com/finegrain-ai/refiners/pull/132).
|
resolution: int,
|
||||||
if flavor.endswith("_reg4"):
|
|
||||||
warn(f"DINOv2 with registers are not yet supported in Hugging Face, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
match flavor:
|
|
||||||
case "dinov2_vits14":
|
|
||||||
name = "dinov2-small"
|
|
||||||
case "dinov2_vitb14":
|
|
||||||
name = "dinov2-base"
|
|
||||||
case "dinov2_vitl14":
|
|
||||||
name = "dinov2-large"
|
|
||||||
case _:
|
|
||||||
raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
|
|
||||||
r = test_weights_path / "facebook" / name
|
|
||||||
if not r.is_dir():
|
|
||||||
warn(f"could not find DINOv2 weights at {r}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def ref_backbone(dinov2_weights_path: Path, test_device: torch.device) -> Dinov2Model:
|
|
||||||
backbone = AutoModel.from_pretrained(dinov2_weights_path) # type: ignore
|
|
||||||
assert isinstance(backbone, Dinov2Model)
|
|
||||||
return backbone.to(test_device) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
def test_encoder(
|
|
||||||
ref_backbone: Dinov2Model,
|
|
||||||
our_backbone: ViT,
|
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
):
|
) -> None:
|
||||||
manual_seed(42)
|
manual_seed(2)
|
||||||
|
input_data = torch.randn(
|
||||||
|
(1, 3, resolution, resolution),
|
||||||
|
device=test_device,
|
||||||
|
)
|
||||||
|
|
||||||
# Position encoding interpolation [1] at runtime is not supported yet. So stick to the default image resolution
|
ref_output = ref_model(input_data, is_training=True)
|
||||||
# e.g. using (224, 224) pixels as input would give a runtime error (sequence size mismatch)
|
ref_cls = ref_output["x_norm_clstoken"]
|
||||||
# [1]: https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/models/vision_transformer.py#L179
|
ref_reg = ref_output["x_norm_regtokens"]
|
||||||
assert our_backbone.image_size == 518
|
ref_patch = ref_output["x_norm_patchtokens"]
|
||||||
|
|
||||||
x = torch.randn(1, 3, 518, 518).to(test_device)
|
our_output = our_model(input_data)
|
||||||
|
our_cls = our_output[:, 0]
|
||||||
|
our_reg = our_output[:, 1 : our_model.num_registers + 1]
|
||||||
|
our_patch = our_output[:, our_model.num_registers + 1 :]
|
||||||
|
|
||||||
with no_grad():
|
assert torch.allclose(ref_cls, our_cls, atol=1e-4)
|
||||||
ref_features = ref_backbone(x).last_hidden_state
|
assert torch.allclose(ref_reg, our_reg, atol=1e-4)
|
||||||
our_features = our_backbone(x)
|
assert torch.allclose(ref_patch, our_patch, atol=3e-3)
|
||||||
|
|
||||||
assert (our_features - ref_features).abs().max() < 1e-3
|
|
||||||
|
|
||||||
|
|
||||||
# Mainly for DINOv2 + registers coverage (this test can be removed once `test_encoder` supports all flavors)
|
@no_grad()
|
||||||
def test_encoder_only(
|
def test_dinov2_float16(
|
||||||
our_backbone: ViT,
|
resolution: int,
|
||||||
seed_expected_norm: tuple[int, float],
|
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
):
|
) -> None:
|
||||||
seed, expected_norm = seed_expected_norm
|
model = DINOv2_small(device=test_device, dtype=torch.float16)
|
||||||
manual_seed(seed)
|
|
||||||
|
|
||||||
x = torch.randn(1, 3, 518, 518).to(test_device)
|
manual_seed(2)
|
||||||
|
input_data = torch.randn(
|
||||||
|
(1, 3, resolution, resolution),
|
||||||
|
device=test_device,
|
||||||
|
dtype=torch.float16,
|
||||||
|
)
|
||||||
|
|
||||||
our_features = our_backbone(x)
|
output = model(input_data)
|
||||||
|
sequence_length = (resolution // model.patch_size) ** 2 + 1
|
||||||
assert isclose(our_features.norm().item(), expected_norm, rel_tol=1e-04)
|
assert output.shape == (1, sequence_length, model.embedding_dim)
|
||||||
|
assert output.dtype == torch.float16
|
||||||
|
|
Loading…
Reference in a new issue