mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 22:58:45 +00:00
add DINOv2-FD metric
This commit is contained in:
parent
c529006d13
commit
09af570b23
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -12,6 +12,7 @@ venv/
|
||||||
# tests' model weights
|
# tests' model weights
|
||||||
tests/weights/
|
tests/weights/
|
||||||
tests/repos/
|
tests/repos/
|
||||||
|
tests/datasets/
|
||||||
|
|
||||||
# ruff
|
# ruff
|
||||||
.ruff_cache
|
.ruff_cache
|
||||||
|
|
|
@ -5,6 +5,7 @@ from .dinov2 import (
|
||||||
DINOv2_large_reg,
|
DINOv2_large_reg,
|
||||||
DINOv2_small,
|
DINOv2_small,
|
||||||
DINOv2_small_reg,
|
DINOv2_small_reg,
|
||||||
|
preprocess,
|
||||||
)
|
)
|
||||||
from .vit import ViT
|
from .vit import ViT
|
||||||
|
|
||||||
|
@ -16,4 +17,5 @@ __all__ = [
|
||||||
"DINOv2_small",
|
"DINOv2_small",
|
||||||
"DINOv2_small_reg",
|
"DINOv2_small_reg",
|
||||||
"ViT",
|
"ViT",
|
||||||
|
"preprocess",
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,9 +1,25 @@
|
||||||
import torch
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import image_to_tensor, normalize
|
||||||
from refiners.foundationals.dinov2.vit import ViT
|
from refiners.foundationals.dinov2.vit import ViT
|
||||||
|
|
||||||
# TODO: add preprocessing logic like
|
|
||||||
# https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/data/transforms.py#L77
|
def preprocess(img: Image.Image, dim: int = 224) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Preprocess an image for use with DINOv2. Uses ImageNet mean and standard deviation.
|
||||||
|
Note that this only resizes and normalizes the image, there is no center crop.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img: The image.
|
||||||
|
dim: The square dimension to resize the image. Typically 224 or 518.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A float32 tensor with shape (3, dim, dim).
|
||||||
|
"""
|
||||||
|
img = img.convert("RGB").resize((dim, dim)) # type: ignore
|
||||||
|
t = image_to_tensor(img).squeeze()
|
||||||
|
return normalize(t, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||||
|
|
||||||
|
|
||||||
class DINOv2_small(ViT):
|
class DINOv2_small(ViT):
|
||||||
|
|
117
src/refiners/training_utils/metrics.py
Normal file
117
src/refiners/training_utils/metrics.py
Normal file
|
@ -0,0 +1,117 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
from refiners.foundationals import dinov2
|
||||||
|
|
||||||
|
|
||||||
|
def get_dinov2_representations(
|
||||||
|
model: dinov2.ViT,
|
||||||
|
dataloader: DataLoader[torch.Tensor],
|
||||||
|
dtype: torch.dtype = torch.float64,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Get DINOV2 representations required to compute DINOv2-FD.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The DINOv2 model to use.
|
||||||
|
dataloader: A dataloader that returns batches of preprocessed images.
|
||||||
|
dtype: The dtype to use for the representations. Use float64 for good precision.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tensor with shape (batch, embedding_dim).
|
||||||
|
"""
|
||||||
|
r: list[torch.Tensor] = []
|
||||||
|
for batch in dataloader:
|
||||||
|
assert isinstance(batch, torch.Tensor)
|
||||||
|
batch_size = batch.shape[0]
|
||||||
|
assert batch.shape == (batch_size, 3, 224, 224)
|
||||||
|
batch = batch.to(model.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
pred = model(batch)[:, 0] # only keep class embeddings
|
||||||
|
|
||||||
|
assert isinstance(pred, torch.Tensor)
|
||||||
|
assert pred.shape == (batch_size, model.embedding_dim)
|
||||||
|
|
||||||
|
r.append(pred.to(dtype))
|
||||||
|
|
||||||
|
return torch.cat(r)
|
||||||
|
|
||||||
|
|
||||||
|
def frechet_distance(reps_a: torch.Tensor, reps_b: torch.Tensor) -> float:
|
||||||
|
"""
|
||||||
|
Compute the Fréchet distance between two sets of representations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reps_a: First set of representations (typically the reference). Shape (batch, N).
|
||||||
|
reps_a: Second set of representations (typically the test set). Shape (batch, N).
|
||||||
|
"""
|
||||||
|
assert reps_a.dim() == 2 and reps_b.dim() == 2, "representations must have shape (batch, N)"
|
||||||
|
assert reps_a.shape[1] == reps_b.shape[1], "representations must have the same dimension"
|
||||||
|
|
||||||
|
mean_a = torch.mean(reps_a, dim=0)
|
||||||
|
cov_a = torch.cov(reps_a.t())
|
||||||
|
mean_b = torch.mean(reps_b, dim=0)
|
||||||
|
cov_b = torch.cov(reps_b.t())
|
||||||
|
|
||||||
|
# The trace of the square root of a matrix is the sum of the square roots of its eigenvalues.
|
||||||
|
trace = (torch.linalg.eigvals(cov_a.mm(cov_b)) ** 0.5).real.sum() # type: ignore
|
||||||
|
assert isinstance(trace, torch.Tensor)
|
||||||
|
|
||||||
|
score = ((mean_a - mean_b) ** 2).sum() + cov_a.trace() + cov_b.trace() - 2 * trace
|
||||||
|
return score.item()
|
||||||
|
|
||||||
|
|
||||||
|
class DinoDataset(Dataset[torch.Tensor]):
|
||||||
|
def __init__(self, path: str | Path) -> None:
|
||||||
|
if isinstance(path, str):
|
||||||
|
path = Path(path)
|
||||||
|
self.image_paths = sorted(path.glob("*.png"))
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.image_paths)
|
||||||
|
|
||||||
|
def __getitem__(self, i: int) -> torch.Tensor:
|
||||||
|
path = self.image_paths[i]
|
||||||
|
img = Image.open(path) # type: ignore
|
||||||
|
return dinov2.preprocess(img)
|
||||||
|
|
||||||
|
|
||||||
|
def dinov2_frechet_distance(
|
||||||
|
dataset_a: Dataset[torch.Tensor] | str | Path,
|
||||||
|
dataset_b: Dataset[torch.Tensor] | str | Path,
|
||||||
|
model: dinov2.ViT,
|
||||||
|
batch_size: int = 64,
|
||||||
|
dtype: torch.dtype = torch.float64,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Compute DINOv2-based Fréchet Distance between two datasets.
|
||||||
|
|
||||||
|
There may be small discrepancies with other implementations due to the fact that DINOv2 in Refiners
|
||||||
|
uses the new style interpolation whereas DINOv2-FD historically uses the legacy implementation
|
||||||
|
(see https://github.com/facebookresearch/dinov2/pull/378)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_a: First dataset (typically the reference). Can also be a path to a directory of PNG images.
|
||||||
|
If a dataset is passed, it must preprocess the data using `dinov2.preprocess`.
|
||||||
|
dataset_b: Second dataset (typically the test set). See `dataset_a` for details. Size can be different.
|
||||||
|
model: The DINOv2 model to use.
|
||||||
|
batch_size: The batch size to use.
|
||||||
|
dtype: The dtype to use for the representations. Use float64 for good precision.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not isinstance(dataset_a, Dataset):
|
||||||
|
dataset_a = DinoDataset(dataset_a)
|
||||||
|
if not isinstance(dataset_b, Dataset):
|
||||||
|
dataset_b = DinoDataset(dataset_b)
|
||||||
|
|
||||||
|
dataloader_a = DataLoader(dataset_a, batch_size=batch_size, shuffle=False)
|
||||||
|
dataloader_b = DataLoader(dataset_b, batch_size=batch_size, shuffle=False)
|
||||||
|
|
||||||
|
reps_a = get_dinov2_representations(model, dataloader_a, dtype)
|
||||||
|
reps_b = get_dinov2_representations(model, dataloader_b, dtype)
|
||||||
|
|
||||||
|
return frechet_distance(reps_a, reps_b)
|
|
@ -6,6 +6,9 @@ from pytest import fixture
|
||||||
|
|
||||||
PARENT_PATH = Path(__file__).parent
|
PARENT_PATH = Path(__file__).parent
|
||||||
|
|
||||||
|
collect_ignore = ["weights", "repos", "datasets"]
|
||||||
|
collect_ignore_glob = ["*_ref"]
|
||||||
|
|
||||||
|
|
||||||
@fixture(scope="session")
|
@fixture(scope="session")
|
||||||
def test_device() -> torch.device:
|
def test_device() -> torch.device:
|
||||||
|
@ -21,6 +24,12 @@ def test_weights_path() -> Path:
|
||||||
return Path(from_env) if from_env else PARENT_PATH / "weights"
|
return Path(from_env) if from_env else PARENT_PATH / "weights"
|
||||||
|
|
||||||
|
|
||||||
|
@fixture(scope="session")
|
||||||
|
def test_datasets_path() -> Path:
|
||||||
|
from_env = os.getenv("REFINERS_TEST_DATASETS_DIR")
|
||||||
|
return Path(from_env) if from_env else PARENT_PATH / "datasets"
|
||||||
|
|
||||||
|
|
||||||
@fixture(scope="session")
|
@fixture(scope="session")
|
||||||
def test_repos_path() -> Path:
|
def test_repos_path() -> Path:
|
||||||
from_env = os.getenv("REFINERS_TEST_REPOS_DIR")
|
from_env = os.getenv("REFINERS_TEST_REPOS_DIR")
|
||||||
|
|
67
tests/training_utils/test_metrics.py
Normal file
67
tests/training_utils/test_metrics.py
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
from pathlib import Path
|
||||||
|
from warnings import warn
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision.datasets import CIFAR10 # type: ignore
|
||||||
|
|
||||||
|
from refiners.foundationals import dinov2
|
||||||
|
from refiners.training_utils.metrics import dinov2_frechet_distance
|
||||||
|
|
||||||
|
|
||||||
|
class CifarDataset(Dataset[torch.Tensor]):
|
||||||
|
def __init__(self, ds: Dataset[list[torch.Tensor]], max_len: int = 512) -> None:
|
||||||
|
self.ds = ds
|
||||||
|
ds_length = len(self.ds) # type: ignore
|
||||||
|
self.length = min(ds_length, max_len)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return self.length
|
||||||
|
|
||||||
|
def __getitem__(self, i: int) -> torch.Tensor:
|
||||||
|
return self.ds[i][0]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def dinov2_l(
|
||||||
|
test_weights_path: Path,
|
||||||
|
test_device: torch.device,
|
||||||
|
) -> dinov2.DINOv2_large:
|
||||||
|
weights = test_weights_path / f"dinov2_vitl14_pretrain.safetensors"
|
||||||
|
if not weights.is_file():
|
||||||
|
warn(f"could not find weights at {weights}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
|
||||||
|
model = dinov2.DINOv2_large(device=test_device)
|
||||||
|
model.load_from_safetensors(weights)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def test_dinov2_frechet_distance(test_datasets_path: Path, dinov2_l: dinov2.DINOv2_large) -> None:
|
||||||
|
path = str(test_datasets_path / "CIFAR10")
|
||||||
|
|
||||||
|
ds_train = CifarDataset(
|
||||||
|
CIFAR10(
|
||||||
|
root=path,
|
||||||
|
train=True,
|
||||||
|
download=True,
|
||||||
|
transform=dinov2.preprocess,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
ds_test = CifarDataset(
|
||||||
|
CIFAR10(
|
||||||
|
root=path,
|
||||||
|
train=False,
|
||||||
|
download=True,
|
||||||
|
transform=dinov2.preprocess,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Computed using dgm-eval (https://github.com/layer6ai-labs/dgm-eval)
|
||||||
|
# with interpolate_offset=0 and random_sample=False.
|
||||||
|
expected_d = 837.978
|
||||||
|
|
||||||
|
d = dinov2_frechet_distance(ds_train, ds_test, dinov2_l)
|
||||||
|
assert expected_d - 1e-2 < d < expected_d + 1e-2
|
Loading…
Reference in a new issue