from pathlib import Path from typing import cast from warnings import warn import numpy as np import pytest import torch from PIL import Image from segment_anything_hq import ( # type: ignore SamPredictor as SamPredictorHQ, sam_model_registry as sam_model_registry_hq, # type: ignore ) from segment_anything_hq.modeling.sam import Sam # type: ignore from tests.foundationals.segment_anything.utils import FacebookSAM, FacebookSAMPredictorHQ, SAMPrompt from torch import optim from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, no_grad from refiners.foundationals.segment_anything.hq_sam import ( CompressViTFeat, EmbeddingEncoder, HQSAMAdapter, HQTokenMLP, MaskDecoderTokensExtender, PredictionsPostProc, ) from refiners.foundationals.segment_anything.model import SegmentAnythingH @pytest.fixture(scope="module") def one_prompt() -> SAMPrompt: return SAMPrompt(box_points=[[(4, 13), (1007, 1023)]]) @pytest.fixture(scope="module") def tennis(ref_path: Path) -> Image.Image: return Image.open(ref_path / "tennis.png").convert("RGB") @pytest.fixture(scope="module") def hq_adapter_weights(test_weights_path: Path) -> Path: """Path to the HQ adapter weights in Refiners format""" refiners_hq_adapter_sam_weights = test_weights_path / "refiners-sam-hq-vit-h.safetensors" if not refiners_hq_adapter_sam_weights.is_file(): warn(f"Test weights not found at {refiners_hq_adapter_sam_weights}, skipping") pytest.skip(allow_module_level=True) return refiners_hq_adapter_sam_weights @pytest.fixture def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH: # HQSAMAdapter is designed to be used with single-output only, hence multimask_output=False. sam_h = SegmentAnythingH(multimask_output=False, device=test_device) sam_h.load_from_safetensors(tensors_path=sam_h_weights) return sam_h @pytest.fixture(scope="module") def reference_hq_adapter_weights(test_weights_path: Path) -> Path: """Path to the HQ adapter weights in default format""" reference_hq_adapter_sam_weights = test_weights_path / "sam_hq_vit_h.pth" if not reference_hq_adapter_sam_weights.is_file(): warn(f"Test weights not found at {reference_hq_adapter_sam_weights}, skipping") pytest.skip(allow_module_level=True) return reference_hq_adapter_sam_weights @pytest.fixture(scope="module") def reference_sam_h(reference_hq_adapter_weights: Path, test_device: torch.device) -> FacebookSAM: sam_h = cast(FacebookSAM, sam_model_registry_hq["vit_h"](checkpoint=reference_hq_adapter_weights)) return sam_h.to(device=test_device) @pytest.fixture(scope="module") def reference_sam_h_predictor(reference_sam_h: FacebookSAM) -> FacebookSAMPredictorHQ: predictor = SamPredictorHQ(cast(Sam, reference_sam_h)) return cast(FacebookSAMPredictorHQ, predictor) def test_inject_eject() -> None: sam_h = SegmentAnythingH(multimask_output=False) initial_repr = repr(sam_h) adapter = HQSAMAdapter(sam_h) assert repr(sam_h) == initial_repr adapter.inject() assert repr(sam_h) != initial_repr adapter.eject() assert repr(sam_h) == initial_repr def test_multimask_forbidden() -> None: with pytest.raises(NotImplementedError, match="not supported"): HQSAMAdapter(target=SegmentAnythingH(multimask_output=True)) def test_output_shape_hq_adapter(tennis: Image.Image, one_prompt: SAMPrompt) -> None: sam_h = SegmentAnythingH(multimask_output=False) HQSAMAdapter(sam_h).inject() high_res_masks, iou_predictions, low_res_masks = sam_h.predict(tennis, **one_prompt.__dict__) assert high_res_masks.shape == (1, 1, 1024, 1024) assert iou_predictions.shape == (1, 1) assert low_res_masks.shape == (1, 1, 256, 256) def test_mask_decoder_tokens_extender() -> None: sam_h = SegmentAnythingH(multimask_output=False) sam_h.requires_grad_(False) # MaskDecoderTokens requires image_embedding context to be set image_embedding = torch.randn(2, 256, 64, 64) sam_h.mask_decoder.set_image_embedding(image_embedding) HQSAMAdapter(sam_h).inject() mask_decoder_tokens = sam_h.ensure_find(MaskDecoderTokensExtender) tokens_before = mask_decoder_tokens() assert tokens_before.shape == torch.Size([2, 6, 256]) for p in mask_decoder_tokens.parameters(): match p.shape: case torch.Size([5, 256]): assert not p.requires_grad case torch.Size([1, 256]): assert p.requires_grad case _: raise ValueError optimizer = optim.SGD(mask_decoder_tokens.parameters(), lr=10) optimizer.zero_grad() ones = torch.ones_like(tokens_before) loss = torch.nn.functional.mse_loss(tokens_before, ones) loss.backward() # type: ignore optimizer.step() tokens_after = mask_decoder_tokens() assert torch.equal(tokens_before[:, :5, :], tokens_after[:, :5, :]) assert not torch.equal(tokens_before[:, 5, :], tokens_after[:, 5, :]) @no_grad() def test_early_vit_embedding( sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM, tennis: Image.Image, ) -> None: HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() image_tensor = image_to_tensor(image=tennis.resize(size=(1024, 1024))) _ = sam_h.image_encoder(image_tensor.to(sam_h.device)) early_vit_embedding_refiners = sam_h.use_context(context_name="hq_sam")["early_vit_embedding"] _, intermediate_embeddings = reference_sam_h.image_encoder(image_tensor.to(reference_sam_h.device)) early_vit_embedding = intermediate_embeddings[0] assert torch.equal(early_vit_embedding, early_vit_embedding_refiners) def test_tokens(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None: HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() mask_decoder_tokens_extender = sam_h.mask_decoder.ensure_find(MaskDecoderTokensExtender) # HF Token (1, 256) assert torch.equal(reference_sam_h.mask_decoder.hf_token.weight, mask_decoder_tokens_extender.hq_token.weight) # Regular Tokens (5, 256) assert torch.equal( torch.cat([reference_sam_h.mask_decoder.iou_token.weight, reference_sam_h.mask_decoder.mask_tokens.weight]), mask_decoder_tokens_extender.regular_tokens.weight, ) @no_grad() def test_compress_vit_feat(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None: HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() early_vit_embedding = torch.randn(1, 64, 64, 1280, device=sam_h.device, dtype=sam_h.dtype) sam_h.set_context(context="hq_sam", value={"early_vit_embedding": early_vit_embedding}) refiners_output = sam_h.ensure_find(CompressViTFeat)() reference_output = reference_sam_h.mask_decoder.compress_vit_feat(early_vit_embedding.permute(0, 3, 1, 2)) assert torch.equal(refiners_output, reference_output) @no_grad() def test_embedding_encoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None: HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() x = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype) sam_h.set_context(context="mask_decoder", value={"image_embedding": x}) refiners_output = sam_h.ensure_find(EmbeddingEncoder)() reference_output = reference_sam_h.mask_decoder.embedding_encoder(x) assert torch.equal(refiners_output, reference_output) @no_grad() def test_hq_token_mlp(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None: HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() x = torch.randn(1, 6, 256, device=sam_h.device, dtype=sam_h.dtype) refiners_output = sam_h.ensure_find(HQTokenMLP)(x) reference_output = reference_sam_h.mask_decoder.hf_mlp(x[:, -1, :]).unsqueeze(0) assert torch.equal(refiners_output, reference_output) @pytest.mark.parametrize("hq_mask_only", [True, False]) def test_predictor( sam_h: SegmentAnythingH, hq_adapter_weights: Path, hq_mask_only: bool, reference_sam_h_predictor: FacebookSAMPredictorHQ, tennis: Image.Image, one_prompt: SAMPrompt, ) -> None: adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() adapter.hq_mask_only = hq_mask_only assert sam_h.ensure_find(PredictionsPostProc).hq_mask_only == hq_mask_only # Refiners high_res_masks, iou_predictions, low_res_masks = sam_h.predict(tennis, **one_prompt.__dict__) refiners_high_res_mask_hq = high_res_masks[0, 0, ...].to(dtype=torch.float32).detach().cpu() refiners_low_res_mask_hq = low_res_masks[0, 0, ...].to(dtype=torch.float32).detach().cpu() iou_predictions = iou_predictions[0, :].to(dtype=torch.float32).detach().cpu() # Reference reference_sam_h_predictor.set_image(np.array(tennis)) predictor_prompt = one_prompt.__dict__["box_points"] masks_np, iou_predictions_np, low_res_masks_np = reference_sam_h_predictor.predict( box=np.array(predictor_prompt).flatten(), multimask_output=False, hq_token_only=hq_mask_only, ) reference_high_res_mask_hq = torch.from_numpy(masks_np[0, ...]).to(dtype=torch.float32) # type: ignore reference_low_res_mask_hq = torch.from_numpy(low_res_masks_np[0, ...]).to(dtype=torch.float32) # type: ignore iou_predictions_np = torch.from_numpy(iou_predictions_np).to(dtype=torch.float32) # type: ignore # NOTE: Diff on logits is relatively high, but on the same scale / even lower than base SAM logits diff (6e-3) # See https://github.com/finegrain-ai/refiners/blob/c6b5eb24a179d48e4542d94684a70c5ef3142ab1/tests/foundationals/segment_anything/test_sam.py#L426 assert torch.allclose( reference_low_res_mask_hq, refiners_low_res_mask_hq, atol=4e-3, ) assert ( torch.abs(reference_high_res_mask_hq - refiners_high_res_mask_hq).flatten().sum() <= 1 ) # The diff on the logits above leads to an absolute diff of 1 pixel on the high res masks assert torch.allclose( iou_predictions_np, torch.max(iou_predictions), atol=1e-5, ) @no_grad() def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) -> None: HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() batch_size = 5 image_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1, 1) mask_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1, 1) dense_positional_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat( batch_size, 1, 1, 1 ) point_embedding = torch.randn(1, 2, 256, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1) early_vit_embedding = torch.randn(1, 64, 64, 1280, device=sam_h.device, dtype=sam_h.dtype).repeat( batch_size, 1, 1, 1 ) sam_h.mask_decoder.set_image_embedding(image_embedding) sam_h.mask_decoder.set_mask_embedding(mask_embedding) sam_h.mask_decoder.set_point_embedding(point_embedding) sam_h.mask_decoder.set_dense_positional_embedding(dense_positional_embedding) sam_h.mask_decoder.set_context( context="hq_sam", value={"early_vit_embedding": early_vit_embedding.to(sam_h.device, sam_h.dtype)} ) mask_prediction, iou_prediction = sam_h.mask_decoder() assert mask_prediction.shape == (batch_size, 1, 256, 256) assert iou_prediction.shape == (batch_size, 1) assert torch.equal(mask_prediction[0], mask_prediction[1])