mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 14:48:45 +00:00
147 lines
4.1 KiB
Python
147 lines
4.1 KiB
Python
from pathlib import Path
|
|
from typing import Any
|
|
from warnings import warn
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from refiners.fluxion.utils import load_from_safetensors, load_tensors, manual_seed, no_grad
|
|
from refiners.foundationals.dinov2.dinov2 import (
|
|
DINOv2_base,
|
|
DINOv2_base_reg,
|
|
DINOv2_large,
|
|
DINOv2_large_reg,
|
|
DINOv2_small,
|
|
DINOv2_small_reg,
|
|
)
|
|
from refiners.foundationals.dinov2.vit import ViT
|
|
|
|
# Flavor name -> model constructor. The keys double as the entry-point names
# handed to `torch.hub.load` when the reference model is built, and (with
# "_reg" rewritten to "_reg4") as the stem of the weight file names.
FLAVORS_MAP = {
    # small
    "dinov2_vits14": DINOv2_small,
    "dinov2_vits14_reg": DINOv2_small_reg,
    # base
    "dinov2_vitb14": DINOv2_base,
    "dinov2_vitb14_reg": DINOv2_base_reg,
    # large
    "dinov2_vitl14": DINOv2_large,
    "dinov2_vitl14_reg": DINOv2_large_reg,
    # TODO: support giant flavors ("dinov2_vitg14" -> DINOv2_giant,
    # "dinov2_vitg14_reg" -> DINOv2_giant_reg)
}
|
|
|
|
|
|
@pytest.fixture(scope="module", params=(224, 518))
def resolution(request: pytest.FixtureRequest) -> int:
    """Parametrized fixture: the side length used to build square test inputs.

    Every test requesting this fixture runs once per value in ``params``.
    """
    side: int = request.param
    return side
|
|
|
|
|
|
@pytest.fixture(scope="module", params=tuple(FLAVORS_MAP))
def flavor(request: pytest.FixtureRequest) -> str:
    """Parametrized fixture: one run per key of ``FLAVORS_MAP``."""
    name: str = request.param
    return name
|
|
|
|
|
|
@pytest.fixture(scope="module")
def dinov2_repo_path(test_repos_path: Path) -> Path:
    """Return the path of the local DINOv2 checkout, skipping when it is absent.

    Args:
        test_repos_path: Directory expected to contain a ``dinov2`` sub-directory.

    Returns:
        ``test_repos_path / "dinov2"`` when that path exists.

    If the path does not exist, a warning is emitted and the requesting tests
    are skipped, with the same message recorded as the skip reason.
    """
    repo = test_repos_path / "dinov2"
    if not repo.exists():
        message = f"could not find DINOv2 GitHub repo at {repo}, skipping"
        warn(message)
        # Pass the message as the skip reason so the test report explains the skip.
        pytest.skip(message, allow_module_level=True)
    return repo
|
|
|
|
|
|
@pytest.fixture(scope="module")
def ref_model(
    flavor: str,
    dinov2_repo_path: Path,
    test_weights_path: Path,
    test_device: torch.device,
) -> torch.nn.Module:
    """Build the reference model for ``flavor`` from the local DINOv2 repo.

    Args:
        flavor: Entry-point name passed to ``torch.hub.load``.
        dinov2_repo_path: Local repository that ``torch.hub.load`` reads from.
        test_weights_path: Directory holding the ``*_pretrain.pth`` checkpoints.
        test_device: Device the model and its weights are placed on.

    Returns:
        The reference module with the checkpoint loaded into it.

    If the checkpoint file is missing, a warning is emitted and the requesting
    tests are skipped, with the same message recorded as the skip reason.
    """
    # Checkpoint files spell the register variants "_reg4" where the flavor name has "_reg".
    checkpoint_stem = flavor.replace("_reg", "_reg4")
    weights = test_weights_path / f"{checkpoint_stem}_pretrain.pth"

    # Check for the weights before building the model, so a skip costs nothing.
    if not weights.is_file():
        message = f"could not find weights at {weights}, skipping"
        warn(message)
        pytest.skip(message, allow_module_level=True)

    # Flavors without "reg" in their name are built with a zero interpolation offset.
    kwargs: dict[str, Any] = {}
    if "reg" not in flavor:
        kwargs["interpolate_offset"] = 0.0

    model = torch.hub.load(  # type: ignore
        model=flavor,
        repo_or_dir=str(dinov2_repo_path),
        source="local",
        pretrained=False,  # to turn off automatic weights download (see load_state_dict below)
        **kwargs,
    ).to(device=test_device)
    model.load_state_dict(load_tensors(weights, device=test_device))

    # torch.hub.load is untyped; narrow the result for the type checker.
    assert isinstance(model, torch.nn.Module)
    return model
|
|
|
|
|
|
@pytest.fixture(scope="module")
def our_model(
    test_weights_path: Path,
    flavor: str,
    test_device: torch.device,
) -> ViT:
    """Build the model registered under ``flavor`` in ``FLAVORS_MAP`` and load its weights.

    Args:
        test_weights_path: Directory holding the ``*_pretrain.safetensors`` files.
        flavor: Key of ``FLAVORS_MAP`` selecting the model constructor.
        test_device: Device the model is created on.

    Returns:
        The model with the safetensors weights loaded into it.

    If the weights file is missing, a warning is emitted and the requesting
    tests are skipped, with the same message recorded as the skip reason.
    """
    # Weight files spell the register variants "_reg4" where the flavor name has "_reg".
    checkpoint_stem = flavor.replace("_reg", "_reg4")
    weights = test_weights_path / f"{checkpoint_stem}_pretrain.safetensors"

    # Check for the weights before building the model, so a skip costs nothing.
    if not weights.is_file():
        message = f"could not find weights at {weights}, skipping"
        warn(message)
        pytest.skip(message, allow_module_level=True)

    model = FLAVORS_MAP[flavor](device=test_device)
    model.load_state_dict(load_from_safetensors(weights))

    return model
|
|
|
|
|
|
@no_grad()
def test_dinov2_facebook_weights(
    ref_model: torch.nn.Module,
    our_model: ViT,
    resolution: int,
    test_device: torch.device,
) -> None:
    """Check that our model matches the reference model on a seeded random input.

    The class token, register tokens and patch tokens are compared separately,
    each with its own absolute tolerance.
    """
    manual_seed(2)
    input_shape = (1, 3, resolution, resolution)
    input_data = torch.randn(input_shape, device=test_device)

    # The reference model returns a dict of named token groups.
    reference = ref_model(input_data, is_training=True)
    expected_cls = reference["x_norm_clstoken"]
    expected_reg = reference["x_norm_regtokens"]
    expected_patch = reference["x_norm_patchtokens"]

    # Our output is sliced along dim 1 as: index 0, then `num_registers`
    # entries, then everything that remains.
    tokens = our_model(input_data)
    first_patch = our_model.num_registers + 1
    actual_cls = tokens[:, 0]
    actual_reg = tokens[:, 1:first_patch]
    actual_patch = tokens[:, first_patch:]

    assert torch.allclose(expected_cls, actual_cls, atol=1e-4)
    assert torch.allclose(expected_reg, actual_reg, atol=1e-4)
    assert torch.allclose(expected_patch, actual_patch, atol=3e-3)
|
|
|
|
|
|
@no_grad()
def test_dinov2_float16(
    resolution: int,
    test_device: torch.device,
) -> None:
    """Run ``DINOv2_small`` in float16 and check the output shape and dtype."""
    half = torch.float16
    model = DINOv2_small(device=test_device, dtype=half)

    manual_seed(2)
    input_shape = (1, 3, resolution, resolution)
    input_data = torch.randn(input_shape, device=test_device, dtype=half)

    output = model(input_data)

    # Expected sequence length: one entry per patch, plus one extra token.
    patches_per_side = resolution // model.patch_size
    sequence_length = patches_per_side**2 + 1
    expected_shape = (1, sequence_length, model.embedding_dim)

    assert output.shape == expected_shape
    assert output.dtype == half
|