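"""Tests for HQSAMAdapter, the HQ-SAM adapter for SegmentAnythingH.

The adapter's sub-modules and end-to-end predictions are compared against the reference
segment_anything_hq implementation (SamPredictorHQ).
"""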

from pathlib import Path
from typing import cast
from warnings import warn

import numpy as np
import pytest
import torch
from PIL import Image
from segment_anything_hq import (  # type: ignore
    SamPredictor as SamPredictorHQ,
    sam_model_registry as sam_model_registry_hq,  # type: ignore
)
from segment_anything_hq.modeling.sam import Sam  # type: ignore
from tests.foundationals.segment_anything.utils import FacebookSAM, FacebookSAMPredictorHQ, SAMPrompt
from torch import optim

from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, no_grad
from refiners.foundationals.segment_anything.hq_sam import (
    CompressViTFeat,
    EmbeddingEncoder,
    HQSAMAdapter,
    HQTokenMLP,
    MaskDecoderTokensExtender,
    PredictionsPostProc,
)
from refiners.foundationals.segment_anything.model import ImageEmbedding, SegmentAnythingH


@pytest.fixture(scope="module")
def one_prompt() -> SAMPrompt:
    return SAMPrompt(box_points=[[(4, 13), (1007, 1023)]])


@pytest.fixture(scope="module")
def tennis(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "tennis.png").convert("RGB")


@pytest.fixture(scope="module")
def hq_adapter_weights(test_weights_path: Path) -> Path:
    """Path to the HQ adapter weights in Refiners format"""
    refiners_hq_adapter_sam_weights = test_weights_path / "refiners-sam-hq-vit-h.safetensors"
    if not refiners_hq_adapter_sam_weights.is_file():
        warn(f"Test weights not found at {refiners_hq_adapter_sam_weights}, skipping")
        pytest.skip(allow_module_level=True)
    return refiners_hq_adapter_sam_weights


@pytest.fixture
def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
    # HQSAMAdapter is designed to be used with single-output only, hence multimask_output=False.
    sam_h = SegmentAnythingH(multimask_output=False, device=test_device)
    sam_h.load_from_safetensors(tensors_path=sam_h_weights)
    return sam_h


@pytest.fixture(scope="module")
def reference_hq_adapter_weights(test_weights_path: Path) -> Path:
    """Path to the HQ adapter weights in default format"""
    reference_hq_adapter_sam_weights = test_weights_path / "sam_hq_vit_h.pth"
    if not reference_hq_adapter_sam_weights.is_file():
        warn(f"Test weights not found at {reference_hq_adapter_sam_weights}, skipping")
        pytest.skip(allow_module_level=True)
    return reference_hq_adapter_sam_weights


@pytest.fixture(scope="module")
def reference_sam_h(reference_hq_adapter_weights: Path, test_device: torch.device) -> FacebookSAM:
    sam_h = cast(FacebookSAM, sam_model_registry_hq["vit_h"](checkpoint=reference_hq_adapter_weights))
    return sam_h.to(device=test_device)


@pytest.fixture(scope="module")
def reference_sam_h_predictor(reference_sam_h: FacebookSAM) -> FacebookSAMPredictorHQ:
    predictor = SamPredictorHQ(cast(Sam, reference_sam_h))
    return cast(FacebookSAMPredictorHQ, predictor)


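# Injecting then ejecting the adapter must leave the model's structure (its repr) unchanged.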
def test_inject_eject() -> None:
    sam_h = SegmentAnythingH(multimask_output=False)
    initial_repr = repr(sam_h)
    adapter = HQSAMAdapter(sam_h)
    assert repr(sam_h) == initial_repr
    adapter.inject()
    assert repr(sam_h) != initial_repr
    adapter.eject()
    assert repr(sam_h) == initial_repr


def test_multimask_forbidden() -> None:
    with pytest.raises(NotImplementedError, match="not supported"):
        HQSAMAdapter(target=SegmentAnythingH(multimask_output=True))


def test_output_shape_hq_adapter(tennis: Image.Image, one_prompt: SAMPrompt) -> None:
    sam_h = SegmentAnythingH(multimask_output=False)
    HQSAMAdapter(sam_h).inject()

    high_res_masks, iou_predictions, low_res_masks = sam_h.predict(tennis, **one_prompt.__dict__)

    assert high_res_masks.shape == (1, 1, 1024, 1024)
    assert iou_predictions.shape == (1, 1)
    assert low_res_masks.shape == (1, 1, 256, 256)


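# MaskDecoderTokensExtender keeps the original SAM tokens (5, 256) frozen and adds a trainable
# HQ token (1, 256): an optimizer step must only update the HQ token.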
def test_mask_decoder_tokens_extender() -> None:
    sam_h = SegmentAnythingH(multimask_output=False)
    sam_h.requires_grad_(False)

    # MaskDecoderTokens requires image_embedding context to be set
    image_embedding = torch.randn(2, 256, 64, 64)
    sam_h.mask_decoder.set_image_embedding(image_embedding)

    HQSAMAdapter(sam_h).inject()

    mask_decoder_tokens = sam_h.ensure_find(MaskDecoderTokensExtender)
    tokens_before = mask_decoder_tokens()
    assert tokens_before.shape == torch.Size([2, 6, 256])

    for p in mask_decoder_tokens.parameters():
        match p.shape:
            case torch.Size([5, 256]):
                assert not p.requires_grad
            case torch.Size([1, 256]):
                assert p.requires_grad
            case _:
                raise ValueError

    optimizer = optim.SGD(mask_decoder_tokens.parameters(), lr=10)
    optimizer.zero_grad()

    ones = torch.ones_like(tokens_before)
    loss = torch.nn.functional.mse_loss(tokens_before, ones)
    loss.backward()  # type: ignore
    optimizer.step()

    tokens_after = mask_decoder_tokens()

    assert torch.equal(tokens_before[:, :5, :], tokens_after[:, :5, :])
    assert not torch.equal(tokens_before[:, 5, :], tokens_after[:, 5, :])


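# The early ViT embedding stored in the "hq_sam" context must match the first intermediate
# embedding returned by the reference image encoder.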
@no_grad()
def test_early_vit_embedding(
    sam_h: SegmentAnythingH,
    hq_adapter_weights: Path,
    reference_sam_h: FacebookSAM,
    tennis: Image.Image,
) -> None:
    HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()

    image_tensor = image_to_tensor(image=tennis.resize(size=(1024, 1024)))

    _ = sam_h.image_encoder(image_tensor.to(sam_h.device))
    early_vit_embedding_refiners = sam_h.use_context(context_name="hq_sam")["early_vit_embedding"]

    _, intermediate_embeddings = reference_sam_h.image_encoder(image_tensor.to(reference_sam_h.device))
    early_vit_embedding = intermediate_embeddings[0]

    assert torch.equal(early_vit_embedding, early_vit_embedding_refiners)


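# Token weights must be loaded verbatim from the reference checkpoint.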
def test_tokens(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
    HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()

    mask_decoder_tokens_extender = sam_h.mask_decoder.ensure_find(MaskDecoderTokensExtender)

    # HF Token (1, 256)
    assert torch.equal(reference_sam_h.mask_decoder.hf_token.weight, mask_decoder_tokens_extender.hq_token.weight)

    # Regular Tokens (5, 256)
    assert torch.equal(
        torch.cat([reference_sam_h.mask_decoder.iou_token.weight, reference_sam_h.mask_decoder.mask_tokens.weight]),
        mask_decoder_tokens_extender.regular_tokens.weight,
    )


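# Component-wise checks: each adapted sub-module must reproduce the reference module exactly on the same input.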
@no_grad()
def test_compress_vit_feat(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
    HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()

    early_vit_embedding = torch.randn(1, 64, 64, 1280, device=sam_h.device, dtype=sam_h.dtype)
    sam_h.set_context(context="hq_sam", value={"early_vit_embedding": early_vit_embedding})

    refiners_output = sam_h.ensure_find(CompressViTFeat)()
    reference_output = reference_sam_h.mask_decoder.compress_vit_feat(early_vit_embedding.permute(0, 3, 1, 2))

    assert torch.equal(refiners_output, reference_output)


@no_grad()
def test_embedding_encoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
    HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()

    x = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype)
    sam_h.set_context(context="mask_decoder", value={"image_embedding": x})

    refiners_output = sam_h.ensure_find(EmbeddingEncoder)()
    reference_output = reference_sam_h.mask_decoder.embedding_encoder(x)

    assert torch.equal(refiners_output, reference_output)


@no_grad()
def test_hq_token_mlp(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
    HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()

    x = torch.randn(1, 6, 256, device=sam_h.device, dtype=sam_h.dtype)

    refiners_output = sam_h.ensure_find(HQTokenMLP)(x)
    reference_output = reference_sam_h.mask_decoder.hf_mlp(x[:, -1, :]).unsqueeze(0)

    assert torch.equal(refiners_output, reference_output)


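# End-to-end comparison against the reference SamPredictorHQ, running the full pipeline
# (image encoder included), hence the tolerances below.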
@pytest.mark.parametrize("hq_mask_only", [True, False])
def test_predictor(
    sam_h: SegmentAnythingH,
    hq_adapter_weights: Path,
    hq_mask_only: bool,
    reference_sam_h_predictor: FacebookSAMPredictorHQ,
    tennis: Image.Image,
    one_prompt: SAMPrompt,
) -> None:
    adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()

    adapter.hq_mask_only = hq_mask_only
    assert sam_h.ensure_find(PredictionsPostProc).hq_mask_only == hq_mask_only

    # Refiners
    high_res_masks, iou_predictions, low_res_masks = sam_h.predict(tennis, **one_prompt.__dict__)

    refiners_high_res_mask_hq = high_res_masks[0, 0, ...].to(dtype=torch.float32).detach().cpu()
    refiners_low_res_mask_hq = low_res_masks[0, 0, ...].to(dtype=torch.float32).detach().cpu()
    iou_predictions = iou_predictions[0, :].to(dtype=torch.float32).detach().cpu()

    # Reference
    reference_sam_h_predictor.set_image(np.array(tennis))

    predictor_prompt = one_prompt.__dict__["box_points"]
    masks_np, iou_predictions_np, low_res_masks_np = reference_sam_h_predictor.predict(
        box=np.array(predictor_prompt).flatten(),
        multimask_output=False,
        hq_token_only=hq_mask_only,
    )

    reference_high_res_mask_hq = torch.from_numpy(masks_np[0, ...]).to(dtype=torch.float32)  # type: ignore
    reference_low_res_mask_hq = torch.from_numpy(low_res_masks_np[0, ...]).to(dtype=torch.float32)  # type: ignore
    iou_predictions_np = torch.from_numpy(iou_predictions_np).to(dtype=torch.float32)  # type: ignore

    # NOTE: Diff on logits is relatively high,
    # see test_predictor_equal for a stricter version
    assert torch.allclose(
        reference_low_res_mask_hq,
        refiners_low_res_mask_hq,
        atol=4e-3,
    )
    assert (
        torch.abs(reference_high_res_mask_hq - refiners_high_res_mask_hq).flatten().sum() <= 1
    )  # The diff on the logits above leads to an absolute diff of 1 pixel on the high res masks
    assert torch.allclose(
        iou_predictions_np,
        torch.max(iou_predictions),
        atol=1e-5,
    )


@pytest.mark.parametrize("hq_mask_only", [True, False])
def test_predictor_equal(
    sam_h: SegmentAnythingH,
    hq_adapter_weights: Path,
    hq_mask_only: bool,
    reference_sam_h_predictor: FacebookSAMPredictorHQ,
    tennis: Image.Image,
    one_prompt: SAMPrompt,
) -> None:
    adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()

    adapter.hq_mask_only = hq_mask_only
    assert sam_h.ensure_find(PredictionsPostProc).hq_mask_only == hq_mask_only

    # See test_predictor_resized_single_output in test_sam.py:
    # to use torch.equal we need to resize the image beforehand
    # and to feed an image embedding as input.
    size = (1024, 1024)
    resized_tennis = tennis.resize(size)

    # Reference
    reference_sam_h_predictor.set_image(np.array(resized_tennis))

    predictor_prompt = one_prompt.__dict__["box_points"]
    masks_np, _, low_res_masks_np = reference_sam_h_predictor.predict(
        box=np.array(predictor_prompt).flatten(),
        multimask_output=False,
        hq_token_only=hq_mask_only,
    )

    reference_high_res_mask_hq = torch.from_numpy(masks_np[0, ...]).to(dtype=torch.float32)  # type: ignore
    reference_low_res_mask_hq = torch.from_numpy(low_res_masks_np[0, ...]).to(dtype=torch.float32)  # type: ignore

    # Refiners
    # We bypass the refiners ViT by directly reusing the image features and interm_features
    # from the reference implementation, which makes torch.equal comparisons possible.
    reference_image_embedding = ImageEmbedding(features=reference_sam_h_predictor.features, original_image_size=size)
    adapter.set_context("hq_sam", {"early_vit_embedding": reference_sam_h_predictor.interm_features[0]})

    high_res_masks, _, low_res_masks = sam_h.predict(reference_image_embedding, **one_prompt.__dict__)

    refiners_high_res_mask_hq = high_res_masks[0, 0, ...].to(dtype=torch.float32).detach().cpu()
    refiners_low_res_mask_hq = low_res_masks[0, 0, ...].to(dtype=torch.float32).detach().cpu()

    assert torch.equal(
        reference_low_res_mask_hq,
        refiners_low_res_mask_hq,
    )
    assert torch.abs(reference_high_res_mask_hq - refiners_high_res_mask_hq).flatten().sum() == 0


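# The same embeddings repeated along the batch dimension must yield identical predictions for every element.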
@no_grad()
def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) -> None:
    HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()

    batch_size = 5

    image_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1, 1)
    mask_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1, 1)
    dense_positional_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(
        batch_size, 1, 1, 1
    )
    point_embedding = torch.randn(1, 2, 256, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1)
    early_vit_embedding = torch.randn(1, 64, 64, 1280, device=sam_h.device, dtype=sam_h.dtype).repeat(
        batch_size, 1, 1, 1
    )

    sam_h.mask_decoder.set_image_embedding(image_embedding)
    sam_h.mask_decoder.set_mask_embedding(mask_embedding)
    sam_h.mask_decoder.set_point_embedding(point_embedding)
    sam_h.mask_decoder.set_dense_positional_embedding(dense_positional_embedding)
    sam_h.mask_decoder.set_context(
        context="hq_sam", value={"early_vit_embedding": early_vit_embedding.to(sam_h.device, sam_h.dtype)}
    )

    mask_prediction, iou_prediction = sam_h.mask_decoder()

    assert mask_prediction.shape == (batch_size, 1, 256, 256)
    assert iou_prediction.shape == (batch_size, 1)
    assert torch.equal(mask_prediction[0], mask_prediction[1])