add DINOv2-FD metric

This commit is contained in:
Pierre Chapuis 2024-04-03 11:54:39 +02:00
parent c529006d13
commit 09af570b23
6 changed files with 214 additions and 2 deletions

1
.gitignore vendored
View file

@ -12,6 +12,7 @@ venv/
# tests' model weights
tests/weights/
tests/repos/
tests/datasets/
# ruff
.ruff_cache

View file

@ -5,6 +5,7 @@ from .dinov2 import (
DINOv2_large_reg,
DINOv2_small,
DINOv2_small_reg,
preprocess,
)
from .vit import ViT
@ -16,4 +17,5 @@ __all__ = [
"DINOv2_small",
"DINOv2_small_reg",
"ViT",
"preprocess",
]

View file

@ -1,9 +1,25 @@
import torch
from PIL import Image
from refiners.fluxion.utils import image_to_tensor, normalize
from refiners.foundationals.dinov2.vit import ViT
# TODO: add preprocessing logic like
# https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/data/transforms.py#L77
def preprocess(img: Image.Image, dim: int = 224) -> torch.Tensor:
"""
Preprocess an image for use with DINOv2. Uses ImageNet mean and standard deviation.
Note that this only resizes and normalizes the image, there is no center crop.
Args:
img: The image.
dim: The square dimension to resize the image. Typically 224 or 518.
Returns:
A float32 tensor with shape (3, dim, dim).
"""
img = img.convert("RGB").resize((dim, dim)) # type: ignore
t = image_to_tensor(img).squeeze()
return normalize(t, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
class DINOv2_small(ViT):

View file

@ -0,0 +1,117 @@
from pathlib import Path
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from refiners.foundationals import dinov2
def get_dinov2_representations(
model: dinov2.ViT,
dataloader: DataLoader[torch.Tensor],
dtype: torch.dtype = torch.float64,
) -> torch.Tensor:
"""
Get DINOV2 representations required to compute DINOv2-FD.
Args:
model: The DINOv2 model to use.
dataloader: A dataloader that returns batches of preprocessed images.
dtype: The dtype to use for the representations. Use float64 for good precision.
Returns:
A tensor with shape (batch, embedding_dim).
"""
r: list[torch.Tensor] = []
for batch in dataloader:
assert isinstance(batch, torch.Tensor)
batch_size = batch.shape[0]
assert batch.shape == (batch_size, 3, 224, 224)
batch = batch.to(model.device)
with torch.no_grad():
pred = model(batch)[:, 0] # only keep class embeddings
assert isinstance(pred, torch.Tensor)
assert pred.shape == (batch_size, model.embedding_dim)
r.append(pred.to(dtype))
return torch.cat(r)
def frechet_distance(reps_a: torch.Tensor, reps_b: torch.Tensor) -> float:
"""
Compute the Fréchet distance between two sets of representations.
Args:
reps_a: First set of representations (typically the reference). Shape (batch, N).
reps_a: Second set of representations (typically the test set). Shape (batch, N).
"""
assert reps_a.dim() == 2 and reps_b.dim() == 2, "representations must have shape (batch, N)"
assert reps_a.shape[1] == reps_b.shape[1], "representations must have the same dimension"
mean_a = torch.mean(reps_a, dim=0)
cov_a = torch.cov(reps_a.t())
mean_b = torch.mean(reps_b, dim=0)
cov_b = torch.cov(reps_b.t())
# The trace of the square root of a matrix is the sum of the square roots of its eigenvalues.
trace = (torch.linalg.eigvals(cov_a.mm(cov_b)) ** 0.5).real.sum() # type: ignore
assert isinstance(trace, torch.Tensor)
score = ((mean_a - mean_b) ** 2).sum() + cov_a.trace() + cov_b.trace() - 2 * trace
return score.item()
class DinoDataset(Dataset[torch.Tensor]):
def __init__(self, path: str | Path) -> None:
if isinstance(path, str):
path = Path(path)
self.image_paths = sorted(path.glob("*.png"))
def __len__(self) -> int:
return len(self.image_paths)
def __getitem__(self, i: int) -> torch.Tensor:
path = self.image_paths[i]
img = Image.open(path) # type: ignore
return dinov2.preprocess(img)
def dinov2_frechet_distance(
dataset_a: Dataset[torch.Tensor] | str | Path,
dataset_b: Dataset[torch.Tensor] | str | Path,
model: dinov2.ViT,
batch_size: int = 64,
dtype: torch.dtype = torch.float64,
) -> float:
"""
Compute DINOv2-based Fréchet Distance between two datasets.
There may be small discrepancies with other implementations due to the fact that DINOv2 in Refiners
uses the new style interpolation whereas DINOv2-FD historically uses the legacy implementation
(see https://github.com/facebookresearch/dinov2/pull/378)
Args:
dataset_a: First dataset (typically the reference). Can also be a path to a directory of PNG images.
If a dataset is passed, it must preprocess the data using `dinov2.preprocess`.
dataset_b: Second dataset (typically the test set). See `dataset_a` for details. Size can be different.
model: The DINOv2 model to use.
batch_size: The batch size to use.
dtype: The dtype to use for the representations. Use float64 for good precision.
"""
if not isinstance(dataset_a, Dataset):
dataset_a = DinoDataset(dataset_a)
if not isinstance(dataset_b, Dataset):
dataset_b = DinoDataset(dataset_b)
dataloader_a = DataLoader(dataset_a, batch_size=batch_size, shuffle=False)
dataloader_b = DataLoader(dataset_b, batch_size=batch_size, shuffle=False)
reps_a = get_dinov2_representations(model, dataloader_a, dtype)
reps_b = get_dinov2_representations(model, dataloader_b, dtype)
return frechet_distance(reps_a, reps_b)

View file

@ -6,6 +6,9 @@ from pytest import fixture
PARENT_PATH = Path(__file__).parent
collect_ignore = ["weights", "repos", "datasets"]
collect_ignore_glob = ["*_ref"]
@fixture(scope="session")
def test_device() -> torch.device:
@ -21,6 +24,12 @@ def test_weights_path() -> Path:
return Path(from_env) if from_env else PARENT_PATH / "weights"
@fixture(scope="session")
def test_datasets_path() -> Path:
from_env = os.getenv("REFINERS_TEST_DATASETS_DIR")
return Path(from_env) if from_env else PARENT_PATH / "datasets"
@fixture(scope="session")
def test_repos_path() -> Path:
from_env = os.getenv("REFINERS_TEST_REPOS_DIR")

View file

@ -0,0 +1,67 @@
from pathlib import Path
from warnings import warn
import pytest
import torch
from torch.utils.data import Dataset
from torchvision.datasets import CIFAR10 # type: ignore
from refiners.foundationals import dinov2
from refiners.training_utils.metrics import dinov2_frechet_distance
class CifarDataset(Dataset[torch.Tensor]):
def __init__(self, ds: Dataset[list[torch.Tensor]], max_len: int = 512) -> None:
self.ds = ds
ds_length = len(self.ds) # type: ignore
self.length = min(ds_length, max_len)
def __len__(self) -> int:
return self.length
def __getitem__(self, i: int) -> torch.Tensor:
return self.ds[i][0]
@pytest.fixture(scope="module")
def dinov2_l(
test_weights_path: Path,
test_device: torch.device,
) -> dinov2.DINOv2_large:
weights = test_weights_path / f"dinov2_vitl14_pretrain.safetensors"
if not weights.is_file():
warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True)
model = dinov2.DINOv2_large(device=test_device)
model.load_from_safetensors(weights)
return model
def test_dinov2_frechet_distance(test_datasets_path: Path, dinov2_l: dinov2.DINOv2_large) -> None:
path = str(test_datasets_path / "CIFAR10")
ds_train = CifarDataset(
CIFAR10(
root=path,
train=True,
download=True,
transform=dinov2.preprocess,
)
)
ds_test = CifarDataset(
CIFAR10(
root=path,
train=False,
download=True,
transform=dinov2.preprocess,
)
)
# Computed using dgm-eval (https://github.com/layer6ai-labs/dgm-eval)
# with interpolate_offset=0 and random_sample=False.
expected_d = 837.978
d = dinov2_frechet_distance(ds_train, ds_test, dinov2_l)
assert expected_d - 1e-2 < d < expected_d + 1e-2