refiners/tests/foundationals/segment_anything/test_hq_sam.py

from pathlib import Path
from typing import cast
from warnings import warn

import numpy as np
import pytest
import torch
from PIL import Image
from segment_anything_hq import (  # type: ignore
    SamPredictor as SamPredictorHQ,
    sam_model_registry as sam_model_registry_hq,  # type: ignore
)
from segment_anything_hq.modeling.sam import Sam  # type: ignore
from tests.foundationals.segment_anything.utils import FacebookSAM, FacebookSAMPredictorHQ, SAMPrompt
from torch import optim

from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, no_grad
from refiners.foundationals.segment_anything.hq_sam import (
    CompressViTFeat,
    EmbeddingEncoder,
    HQSAMAdapter,
    HQTokenMLP,
    MaskDecoderTokensExtender,
    PredictionsPostProc,
)
from refiners.foundationals.segment_anything.model import SegmentAnythingH
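
# The tests below all exercise the same basic flow, shown here as a rough, non-executed sketch
# (`hq_adapter_weights`, `tennis` and `one_prompt` are fixtures defined below; `sam_h_weights` and
# `test_device` presumably come from shared fixtures):
#
#     sam_h = SegmentAnythingH(multimask_output=False, device=test_device)
#     sam_h.load_from_safetensors(tensors_path=sam_h_weights)
#     HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
#     high_res_masks, iou_predictions, low_res_masks = sam_h.predict(tennis, **one_prompt.__dict__)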
@pytest.fixture(scope="module")
def one_prompt() -> SAMPrompt:
return SAMPrompt(box_points=[[(4, 13), (1007, 1023)]])
@pytest.fixture(scope="module")
def tennis(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "tennis.png").convert("RGB")
@pytest.fixture(scope="module")
def hq_adapter_weights(test_weights_path: Path) -> Path:
"""Path to the HQ adapter weights in Refiners format"""
refiners_hq_adapter_sam_weights = test_weights_path / "refiners-sam-hq-vit-h.safetensors"
if not refiners_hq_adapter_sam_weights.is_file():
warn(f"Test weights not found at {refiners_hq_adapter_sam_weights}, skipping")
pytest.skip(allow_module_level=True)
return refiners_hq_adapter_sam_weights


@pytest.fixture
def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
    # HQSAMAdapter is designed for single-mask output only, hence multimask_output=False.
    sam_h = SegmentAnythingH(multimask_output=False, device=test_device)
    sam_h.load_from_safetensors(tensors_path=sam_h_weights)
    return sam_h
@pytest.fixture(scope="module")
def reference_hq_adapter_weights(test_weights_path: Path) -> Path:
"""Path to the HQ adapter weights in default format"""
reference_hq_adapter_sam_weights = test_weights_path / "sam_hq_vit_h.pth"
if not reference_hq_adapter_sam_weights.is_file():
warn(f"Test weights not found at {reference_hq_adapter_sam_weights}, skipping")
pytest.skip(allow_module_level=True)
return reference_hq_adapter_sam_weights
@pytest.fixture(scope="module")
def reference_sam_h(reference_hq_adapter_weights: Path, test_device: torch.device) -> FacebookSAM:
sam_h = cast(FacebookSAM, sam_model_registry_hq["vit_h"](checkpoint=reference_hq_adapter_weights))
return sam_h.to(device=test_device)
@pytest.fixture(scope="module")
def reference_sam_h_predictor(reference_sam_h: FacebookSAM) -> FacebookSAMPredictorHQ:
predictor = SamPredictorHQ(cast(Sam, reference_sam_h))
return cast(FacebookSAMPredictorHQ, predictor)


def test_inject_eject() -> None:
    """Injecting then ejecting the adapter must leave the target model unchanged."""
    sam_h = SegmentAnythingH(multimask_output=False)
    initial_repr = repr(sam_h)

    adapter = HQSAMAdapter(sam_h)
    assert repr(sam_h) == initial_repr

    adapter.inject()
    assert repr(sam_h) != initial_repr

    adapter.eject()
    assert repr(sam_h) == initial_repr


def test_multimask_forbidden() -> None:
    """HQSAMAdapter only supports single-mask output targets."""
    with pytest.raises(NotImplementedError, match="not supported"):
        HQSAMAdapter(target=SegmentAnythingH(multimask_output=True))


def test_output_shape_hq_adapter(tennis: Image.Image, one_prompt: SAMPrompt) -> None:
    """Check the output shapes of predict() with the adapter injected."""
    sam_h = SegmentAnythingH(multimask_output=False)
    HQSAMAdapter(sam_h).inject()

    high_res_masks, iou_predictions, low_res_masks = sam_h.predict(tennis, **one_prompt.__dict__)

    assert high_res_masks.shape == (1, 1, 1024, 1024)
    assert iou_predictions.shape == (1, 1)
    assert low_res_masks.shape == (1, 1, 256, 256)


def test_mask_decoder_tokens_extender() -> None:
    """Only the HQ token of the extended mask decoder tokens should be trainable."""
    sam_h = SegmentAnythingH(multimask_output=False)
    sam_h.requires_grad_(False)

    # MaskDecoderTokens requires the image_embedding context to be set
    image_embedding = torch.randn(2, 256, 64, 64)
    sam_h.mask_decoder.set_image_embedding(image_embedding)

    HQSAMAdapter(sam_h).inject()

    mask_decoder_tokens = sam_h.ensure_find(MaskDecoderTokensExtender)
    tokens_before = mask_decoder_tokens()
    assert tokens_before.shape == torch.Size([2, 6, 256])

    # The 5 regular tokens stay frozen, only the (1, 256) HQ token requires gradients.
    for p in mask_decoder_tokens.parameters():
        match p.shape:
            case torch.Size([5, 256]):
                assert not p.requires_grad
            case torch.Size([1, 256]):
                assert p.requires_grad
            case _:
                raise ValueError(f"unexpected parameter shape: {p.shape}")

    # A single optimizer step should therefore only change the HQ token.
    optimizer = optim.SGD(mask_decoder_tokens.parameters(), lr=10)
    optimizer.zero_grad()

    ones = torch.ones_like(tokens_before)
    loss = torch.nn.functional.mse_loss(tokens_before, ones)
    loss.backward()  # type: ignore
    optimizer.step()

    tokens_after = mask_decoder_tokens()

    assert torch.equal(tokens_before[:, :5, :], tokens_after[:, :5, :])
    assert not torch.equal(tokens_before[:, 5, :], tokens_after[:, 5, :])


@no_grad()
def test_early_vit_embedding(
    sam_h: SegmentAnythingH,
    hq_adapter_weights: Path,
    reference_sam_h: FacebookSAM,
    tennis: Image.Image,
) -> None:
    """The early ViT embedding stored in the "hq_sam" context must match the reference intermediate embedding."""
    HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()

    image_tensor = image_to_tensor(image=tennis.resize(size=(1024, 1024)))

    _ = sam_h.image_encoder(image_tensor.to(sam_h.device))
    early_vit_embedding_refiners = sam_h.use_context(context_name="hq_sam")["early_vit_embedding"]

    _, intermediate_embeddings = reference_sam_h.image_encoder(image_tensor.to(reference_sam_h.device))
    early_vit_embedding = intermediate_embeddings[0]

    assert torch.equal(early_vit_embedding, early_vit_embedding_refiners)


def test_tokens(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
    """The adapter token weights must match the reference HQ-SAM token weights."""
    HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()

    mask_decoder_tokens_extender = sam_h.mask_decoder.ensure_find(MaskDecoderTokensExtender)

    # HF Token (1, 256)
    assert torch.equal(reference_sam_h.mask_decoder.hf_token.weight, mask_decoder_tokens_extender.hq_token.weight)

    # Regular Tokens (5, 256)
    assert torch.equal(
        torch.cat([reference_sam_h.mask_decoder.iou_token.weight, reference_sam_h.mask_decoder.mask_tokens.weight]),
        mask_decoder_tokens_extender.regular_tokens.weight,
    )
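
# Note: `hf_token` is the reference implementation's name for the trainable HQ token, while the 5
# "regular" tokens (the IoU token followed by the mask tokens) stay frozen; together they form the
# 6 mask decoder tokens checked in test_mask_decoder_tokens_extender above.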


@no_grad()
def test_compress_vit_feat(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
    """CompressViTFeat must match the reference mask decoder's compress_vit_feat on the same input."""
    HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()

    early_vit_embedding = torch.randn(1, 64, 64, 1280, device=sam_h.device, dtype=sam_h.dtype)
    sam_h.set_context(context="hq_sam", value={"early_vit_embedding": early_vit_embedding})

    refiners_output = sam_h.ensure_find(CompressViTFeat)()
    # The reference implementation expects a channels-first (N, C, H, W) tensor.
    reference_output = reference_sam_h.mask_decoder.compress_vit_feat(early_vit_embedding.permute(0, 3, 1, 2))

    assert torch.equal(refiners_output, reference_output)


@no_grad()
def test_embedding_encoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
    """EmbeddingEncoder must match the reference mask decoder's embedding_encoder on the same input."""
    HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()

    x = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype)
    sam_h.set_context(context="mask_decoder", value={"image_embedding": x})

    refiners_output = sam_h.ensure_find(EmbeddingEncoder)()
    reference_output = reference_sam_h.mask_decoder.embedding_encoder(x)

    assert torch.equal(refiners_output, reference_output)


@no_grad()
def test_hq_token_mlp(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
    """HQTokenMLP must match the reference mask decoder's hf_mlp applied to the last (HQ) token."""
    HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()

    x = torch.randn(1, 6, 256, device=sam_h.device, dtype=sam_h.dtype)

    refiners_output = sam_h.ensure_find(HQTokenMLP)(x)
    reference_output = reference_sam_h.mask_decoder.hf_mlp(x[:, -1, :]).unsqueeze(0)

    assert torch.equal(refiners_output, reference_output)


@pytest.mark.parametrize("hq_mask_only", [True, False])
def test_predictor(
    sam_h: SegmentAnythingH,
    hq_adapter_weights: Path,
    hq_mask_only: bool,
    reference_sam_h_predictor: FacebookSAMPredictorHQ,
    tennis: Image.Image,
    one_prompt: SAMPrompt,
) -> None:
    """End-to-end comparison of predictions against the reference SamPredictorHQ."""
    adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()

    adapter.hq_mask_only = hq_mask_only
    assert sam_h.ensure_find(PredictionsPostProc).hq_mask_only == hq_mask_only

    # Refiners
    high_res_masks, iou_predictions, low_res_masks = sam_h.predict(tennis, **one_prompt.__dict__)
    refiners_high_res_mask_hq = high_res_masks[0, 0, ...].to(dtype=torch.float32).detach().cpu()
    refiners_low_res_mask_hq = low_res_masks[0, 0, ...].to(dtype=torch.float32).detach().cpu()
    iou_predictions = iou_predictions[0, :].to(dtype=torch.float32).detach().cpu()

    # Reference
    reference_sam_h_predictor.set_image(np.array(tennis))

    predictor_prompt = one_prompt.__dict__["box_points"]
    masks_np, iou_predictions_np, low_res_masks_np = reference_sam_h_predictor.predict(
        box=np.array(predictor_prompt).flatten(),
        multimask_output=False,
        hq_token_only=hq_mask_only,
    )

    reference_high_res_mask_hq = torch.from_numpy(masks_np[0, ...]).to(dtype=torch.float32)  # type: ignore
    reference_low_res_mask_hq = torch.from_numpy(low_res_masks_np[0, ...]).to(dtype=torch.float32)  # type: ignore
    iou_predictions_np = torch.from_numpy(iou_predictions_np).to(dtype=torch.float32)  # type: ignore

    # NOTE: The diff on the logits is relatively high, but on the same scale as (or even lower than)
    # the base SAM logits diff (6e-3). See
    # https://github.com/finegrain-ai/refiners/blob/c6b5eb24a179d48e4542d94684a70c5ef3142ab1/tests/foundationals/segment_anything/test_sam.py#L426
    assert torch.allclose(
        reference_low_res_mask_hq,
        refiners_low_res_mask_hq,
        atol=4e-3,
    )

    # The diff on the logits above leads to an absolute diff of at most 1 pixel on the high res masks.
    assert torch.abs(reference_high_res_mask_hq - refiners_high_res_mask_hq).flatten().sum() <= 1

    assert torch.allclose(
        iou_predictions_np,
        torch.max(iou_predictions),
        atol=1e-5,
    )


@no_grad()
def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) -> None:
    """The mask decoder must support batched inputs with the adapter injected."""
    HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()

    batch_size = 5

    # All batch elements are identical repeats of the same random embeddings.
    image_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1, 1)
    mask_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1, 1)
    dense_positional_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(
        batch_size, 1, 1, 1
    )
    point_embedding = torch.randn(1, 2, 256, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1)
    early_vit_embedding = torch.randn(1, 64, 64, 1280, device=sam_h.device, dtype=sam_h.dtype).repeat(
        batch_size, 1, 1, 1
    )

    sam_h.mask_decoder.set_image_embedding(image_embedding)
    sam_h.mask_decoder.set_mask_embedding(mask_embedding)
    sam_h.mask_decoder.set_point_embedding(point_embedding)
    sam_h.mask_decoder.set_dense_positional_embedding(dense_positional_embedding)
    sam_h.mask_decoder.set_context(
        context="hq_sam", value={"early_vit_embedding": early_vit_embedding.to(sam_h.device, sam_h.dtype)}
    )

    mask_prediction, iou_prediction = sam_h.mask_decoder()

    assert mask_prediction.shape == (batch_size, 1, 256, 256)
    assert iou_prediction.shape == (batch_size, 1)
    # Since every batch element is identical, every prediction must be identical too.
    assert torch.equal(mask_prediction[0], mask_prediction[1])