From 5c937b184a971f76d46a16b95b1b31c4be492729 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Fri, 22 Mar 2024 21:30:46 +0000 Subject: [PATCH] HQ-SAM logit equal test, following #331 --- pyproject.toml | 2 +- .../segment_anything/test_hq_sam.py | 58 ++++++++++++++++++- tests/foundationals/segment_anything/utils.py | 2 + 3 files changed, 58 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b6bb35a..901a6b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -144,7 +144,7 @@ exclude_also = [ [tool.typos.default] extend-words = { adaptee = "adaptee" } -extend-ignore-identifiers-re = ["NDArray*"] +extend-ignore-identifiers-re = ["NDArray*", "interm"] [tool.pytest.ini_options] filterwarnings = [ diff --git a/tests/foundationals/segment_anything/test_hq_sam.py b/tests/foundationals/segment_anything/test_hq_sam.py index 3d79eaf..8615119 100644 --- a/tests/foundationals/segment_anything/test_hq_sam.py +++ b/tests/foundationals/segment_anything/test_hq_sam.py @@ -23,7 +23,7 @@ from refiners.foundationals.segment_anything.hq_sam import ( MaskDecoderTokensExtender, PredictionsPostProc, ) -from refiners.foundationals.segment_anything.model import SegmentAnythingH +from refiners.foundationals.segment_anything.model import ImageEmbedding, SegmentAnythingH @pytest.fixture(scope="module") @@ -248,8 +248,8 @@ def test_predictor( reference_low_res_mask_hq = torch.from_numpy(low_res_masks_np[0, ...]).to(dtype=torch.float32) # type: ignore iou_predictions_np = torch.from_numpy(iou_predictions_np).to(dtype=torch.float32) # type: ignore - # NOTE: Diff on logits is relatively high, but on the same scale / even lower than base SAM logits diff (6e-3) - # See https://github.com/finegrain-ai/refiners/blob/c6b5eb24a179d48e4542d94684a70c5ef3142ab1/tests/foundationals/segment_anything/test_sam.py#L426 + # NOTE: Diff on logits is relatively high, + # see test_predictor_equal for a stricter version assert torch.allclose( reference_low_res_mask_hq, 
refiners_low_res_mask_hq, @@ -265,6 +265,58 @@ def test_predictor( ) +@pytest.mark.parametrize("hq_mask_only", [True, False]) +def test_predictor_equal( + sam_h: SegmentAnythingH, + hq_adapter_weights: Path, + hq_mask_only: bool, + reference_sam_h_predictor: FacebookSAMPredictorHQ, + tennis: Image.Image, + one_prompt: SAMPrompt, +) -> None: + adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() + + adapter.hq_mask_only = hq_mask_only + assert sam_h.ensure_find(PredictionsPostProc).hq_mask_only == hq_mask_only + + # See test_predictor_resized_single_output in test_sam.py: + # to use torch.equal we need to resize the image beforehand + # and to use image_embedding as input + + size = (1024, 1024) + resized_tennis = tennis.resize(size) + + # Reference + reference_sam_h_predictor.set_image(np.array(resized_tennis)) + + predictor_prompt = one_prompt.__dict__["box_points"] + masks_np, _, low_res_masks_np = reference_sam_h_predictor.predict( + box=np.array(predictor_prompt).flatten(), + multimask_output=False, + hq_token_only=hq_mask_only, + ) + + reference_high_res_mask_hq = torch.from_numpy(masks_np[0, ...]).to(dtype=torch.float32) # type: ignore + reference_low_res_mask_hq = torch.from_numpy(low_res_masks_np[0, ...]).to(dtype=torch.float32) # type: ignore + + # Refiners + + # We bypass the refiners ViT by directly using the image features and interm_features + # from the reference implementation: this makes an exact torch.equal comparison possible + reference_image_embedding = ImageEmbedding(features=reference_sam_h_predictor.features, original_image_size=size) + adapter.set_context("hq_sam", {"early_vit_embedding": reference_sam_h_predictor.interm_features[0]}) + + high_res_masks, _, low_res_masks = sam_h.predict(reference_image_embedding, **one_prompt.__dict__) + refiners_high_res_mask_hq = high_res_masks[0, 0, ...].to(dtype=torch.float32).detach().cpu() + refiners_low_res_mask_hq = low_res_masks[0, 0, ...].to(dtype=torch.float32).detach().cpu() + 
+ assert torch.equal( + reference_low_res_mask_hq, + refiners_low_res_mask_hq, + ) + assert torch.abs(reference_high_res_mask_hq - refiners_high_res_mask_hq).flatten().sum() == 0 + + @no_grad() def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) -> None: HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() diff --git a/tests/foundationals/segment_anything/utils.py b/tests/foundationals/segment_anything/utils.py index e98fe6d..aac97f9 100644 --- a/tests/foundationals/segment_anything/utils.py +++ b/tests/foundationals/segment_anything/utils.py @@ -57,6 +57,8 @@ class FacebookSAMPredictor: class FacebookSAMPredictorHQ: model: FacebookSAM + features: Tensor + interm_features: Tensor def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ...