mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-12 16:18:22 +00:00
add DINOv2-FD metric
This commit is contained in:
parent
c529006d13
commit
09af570b23
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -12,6 +12,7 @@ venv/
|
|||
# tests' model weights
|
||||
tests/weights/
|
||||
tests/repos/
|
||||
tests/datasets/
|
||||
|
||||
# ruff
|
||||
.ruff_cache
|
||||
|
|
|
@ -5,6 +5,7 @@ from .dinov2 import (
|
|||
DINOv2_large_reg,
|
||||
DINOv2_small,
|
||||
DINOv2_small_reg,
|
||||
preprocess,
|
||||
)
|
||||
from .vit import ViT
|
||||
|
||||
|
@ -16,4 +17,5 @@ __all__ = [
|
|||
"DINOv2_small",
|
||||
"DINOv2_small_reg",
|
||||
"ViT",
|
||||
"preprocess",
|
||||
]
|
||||
|
|
|
@ -1,9 +1,25 @@
|
|||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from refiners.fluxion.utils import image_to_tensor, normalize
|
||||
from refiners.foundationals.dinov2.vit import ViT
|
||||
|
||||
# TODO: add preprocessing logic like
|
||||
# https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/data/transforms.py#L77
|
||||
|
||||
def preprocess(img: Image.Image, dim: int = 224) -> torch.Tensor:
|
||||
"""
|
||||
Preprocess an image for use with DINOv2. Uses ImageNet mean and standard deviation.
|
||||
Note that this only resizes and normalizes the image, there is no center crop.
|
||||
|
||||
Args:
|
||||
img: The image.
|
||||
dim: The square dimension to resize the image. Typically 224 or 518.
|
||||
|
||||
Returns:
|
||||
A float32 tensor with shape (3, dim, dim).
|
||||
"""
|
||||
img = img.convert("RGB").resize((dim, dim)) # type: ignore
|
||||
t = image_to_tensor(img).squeeze()
|
||||
return normalize(t, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
|
||||
|
||||
class DINOv2_small(ViT):
|
||||
|
|
117
src/refiners/training_utils/metrics.py
Normal file
117
src/refiners/training_utils/metrics.py
Normal file
|
@ -0,0 +1,117 @@
|
|||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from refiners.foundationals import dinov2
|
||||
|
||||
|
||||
def get_dinov2_representations(
|
||||
model: dinov2.ViT,
|
||||
dataloader: DataLoader[torch.Tensor],
|
||||
dtype: torch.dtype = torch.float64,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Get DINOV2 representations required to compute DINOv2-FD.
|
||||
|
||||
Args:
|
||||
model: The DINOv2 model to use.
|
||||
dataloader: A dataloader that returns batches of preprocessed images.
|
||||
dtype: The dtype to use for the representations. Use float64 for good precision.
|
||||
|
||||
Returns:
|
||||
A tensor with shape (batch, embedding_dim).
|
||||
"""
|
||||
r: list[torch.Tensor] = []
|
||||
for batch in dataloader:
|
||||
assert isinstance(batch, torch.Tensor)
|
||||
batch_size = batch.shape[0]
|
||||
assert batch.shape == (batch_size, 3, 224, 224)
|
||||
batch = batch.to(model.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = model(batch)[:, 0] # only keep class embeddings
|
||||
|
||||
assert isinstance(pred, torch.Tensor)
|
||||
assert pred.shape == (batch_size, model.embedding_dim)
|
||||
|
||||
r.append(pred.to(dtype))
|
||||
|
||||
return torch.cat(r)
|
||||
|
||||
|
||||
def frechet_distance(reps_a: torch.Tensor, reps_b: torch.Tensor) -> float:
|
||||
"""
|
||||
Compute the Fréchet distance between two sets of representations.
|
||||
|
||||
Args:
|
||||
reps_a: First set of representations (typically the reference). Shape (batch, N).
|
||||
reps_a: Second set of representations (typically the test set). Shape (batch, N).
|
||||
"""
|
||||
assert reps_a.dim() == 2 and reps_b.dim() == 2, "representations must have shape (batch, N)"
|
||||
assert reps_a.shape[1] == reps_b.shape[1], "representations must have the same dimension"
|
||||
|
||||
mean_a = torch.mean(reps_a, dim=0)
|
||||
cov_a = torch.cov(reps_a.t())
|
||||
mean_b = torch.mean(reps_b, dim=0)
|
||||
cov_b = torch.cov(reps_b.t())
|
||||
|
||||
# The trace of the square root of a matrix is the sum of the square roots of its eigenvalues.
|
||||
trace = (torch.linalg.eigvals(cov_a.mm(cov_b)) ** 0.5).real.sum() # type: ignore
|
||||
assert isinstance(trace, torch.Tensor)
|
||||
|
||||
score = ((mean_a - mean_b) ** 2).sum() + cov_a.trace() + cov_b.trace() - 2 * trace
|
||||
return score.item()
|
||||
|
||||
|
||||
class DinoDataset(Dataset[torch.Tensor]):
|
||||
def __init__(self, path: str | Path) -> None:
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
self.image_paths = sorted(path.glob("*.png"))
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.image_paths)
|
||||
|
||||
def __getitem__(self, i: int) -> torch.Tensor:
|
||||
path = self.image_paths[i]
|
||||
img = Image.open(path) # type: ignore
|
||||
return dinov2.preprocess(img)
|
||||
|
||||
|
||||
def dinov2_frechet_distance(
|
||||
dataset_a: Dataset[torch.Tensor] | str | Path,
|
||||
dataset_b: Dataset[torch.Tensor] | str | Path,
|
||||
model: dinov2.ViT,
|
||||
batch_size: int = 64,
|
||||
dtype: torch.dtype = torch.float64,
|
||||
) -> float:
|
||||
"""
|
||||
Compute DINOv2-based Fréchet Distance between two datasets.
|
||||
|
||||
There may be small discrepancies with other implementations due to the fact that DINOv2 in Refiners
|
||||
uses the new style interpolation whereas DINOv2-FD historically uses the legacy implementation
|
||||
(see https://github.com/facebookresearch/dinov2/pull/378)
|
||||
|
||||
Args:
|
||||
dataset_a: First dataset (typically the reference). Can also be a path to a directory of PNG images.
|
||||
If a dataset is passed, it must preprocess the data using `dinov2.preprocess`.
|
||||
dataset_b: Second dataset (typically the test set). See `dataset_a` for details. Size can be different.
|
||||
model: The DINOv2 model to use.
|
||||
batch_size: The batch size to use.
|
||||
dtype: The dtype to use for the representations. Use float64 for good precision.
|
||||
"""
|
||||
|
||||
if not isinstance(dataset_a, Dataset):
|
||||
dataset_a = DinoDataset(dataset_a)
|
||||
if not isinstance(dataset_b, Dataset):
|
||||
dataset_b = DinoDataset(dataset_b)
|
||||
|
||||
dataloader_a = DataLoader(dataset_a, batch_size=batch_size, shuffle=False)
|
||||
dataloader_b = DataLoader(dataset_b, batch_size=batch_size, shuffle=False)
|
||||
|
||||
reps_a = get_dinov2_representations(model, dataloader_a, dtype)
|
||||
reps_b = get_dinov2_representations(model, dataloader_b, dtype)
|
||||
|
||||
return frechet_distance(reps_a, reps_b)
|
|
@ -6,6 +6,9 @@ from pytest import fixture
|
|||
|
||||
PARENT_PATH = Path(__file__).parent
|
||||
|
||||
collect_ignore = ["weights", "repos", "datasets"]
|
||||
collect_ignore_glob = ["*_ref"]
|
||||
|
||||
|
||||
@fixture(scope="session")
|
||||
def test_device() -> torch.device:
|
||||
|
@ -21,6 +24,12 @@ def test_weights_path() -> Path:
|
|||
return Path(from_env) if from_env else PARENT_PATH / "weights"
|
||||
|
||||
|
||||
@fixture(scope="session")
|
||||
def test_datasets_path() -> Path:
|
||||
from_env = os.getenv("REFINERS_TEST_DATASETS_DIR")
|
||||
return Path(from_env) if from_env else PARENT_PATH / "datasets"
|
||||
|
||||
|
||||
@fixture(scope="session")
|
||||
def test_repos_path() -> Path:
|
||||
from_env = os.getenv("REFINERS_TEST_REPOS_DIR")
|
||||
|
|
67
tests/training_utils/test_metrics.py
Normal file
67
tests/training_utils/test_metrics.py
Normal file
|
@ -0,0 +1,67 @@
|
|||
from pathlib import Path
|
||||
from warnings import warn
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.datasets import CIFAR10 # type: ignore
|
||||
|
||||
from refiners.foundationals import dinov2
|
||||
from refiners.training_utils.metrics import dinov2_frechet_distance
|
||||
|
||||
|
||||
class CifarDataset(Dataset[torch.Tensor]):
|
||||
def __init__(self, ds: Dataset[list[torch.Tensor]], max_len: int = 512) -> None:
|
||||
self.ds = ds
|
||||
ds_length = len(self.ds) # type: ignore
|
||||
self.length = min(ds_length, max_len)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, i: int) -> torch.Tensor:
|
||||
return self.ds[i][0]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def dinov2_l(
|
||||
test_weights_path: Path,
|
||||
test_device: torch.device,
|
||||
) -> dinov2.DINOv2_large:
|
||||
weights = test_weights_path / f"dinov2_vitl14_pretrain.safetensors"
|
||||
if not weights.is_file():
|
||||
warn(f"could not find weights at {weights}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
|
||||
model = dinov2.DINOv2_large(device=test_device)
|
||||
model.load_from_safetensors(weights)
|
||||
return model
|
||||
|
||||
|
||||
def test_dinov2_frechet_distance(test_datasets_path: Path, dinov2_l: dinov2.DINOv2_large) -> None:
|
||||
path = str(test_datasets_path / "CIFAR10")
|
||||
|
||||
ds_train = CifarDataset(
|
||||
CIFAR10(
|
||||
root=path,
|
||||
train=True,
|
||||
download=True,
|
||||
transform=dinov2.preprocess,
|
||||
)
|
||||
)
|
||||
|
||||
ds_test = CifarDataset(
|
||||
CIFAR10(
|
||||
root=path,
|
||||
train=False,
|
||||
download=True,
|
||||
transform=dinov2.preprocess,
|
||||
)
|
||||
)
|
||||
|
||||
# Computed using dgm-eval (https://github.com/layer6ai-labs/dgm-eval)
|
||||
# with interpolate_offset=0 and random_sample=False.
|
||||
expected_d = 837.978
|
||||
|
||||
d = dinov2_frechet_distance(ds_train, ds_test, dinov2_l)
|
||||
assert expected_d - 1e-2 < d < expected_d + 1e-2
|
Loading…
Reference in a new issue