From 2763db960efbf9d887df2d5c099b2e8d48c9757c Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Fri, 22 Mar 2024 14:22:37 +0000 Subject: [PATCH] SAM e2e test tolerance explained --- .../segment_anything/test_sam.py | 41 ++++++++++++++++++- tests/foundationals/segment_anything/utils.py | 1 + 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/tests/foundationals/segment_anything/test_sam.py b/tests/foundationals/segment_anything/test_sam.py index 9a8f4fb..f521f14 100644 --- a/tests/foundationals/segment_anything/test_sam.py +++ b/tests/foundationals/segment_anything/test_sam.py @@ -21,7 +21,7 @@ from refiners.fluxion import manual_seed from refiners.fluxion.model_converter import ModelConverter from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention, RelativePositionAttention -from refiners.foundationals.segment_anything.model import SegmentAnythingH +from refiners.foundationals.segment_anything.model import ImageEmbedding, SegmentAnythingH from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer # See predictor_example.ipynb official notebook @@ -409,7 +409,7 @@ def test_predictor_single_output( assert torch.allclose( low_res_masks[0, 0, ...], torch.as_tensor(facebook_low_res_masks[0], device=sam_h_single_output.device), - atol=6e-3, # TODO: This diff on logits is high, and requires deeper investigation + atol=6e-3, # see test_predictor_resized_single_output for more explanation ) assert isclose(scores[0].item(), facebook_scores[0].item(), abs_tol=1e-05) @@ -418,6 +418,43 @@ def test_predictor_single_output( assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05) +def test_predictor_resized_single_output( + facebook_sam_h_predictor: FacebookSAMPredictor, + sam_h_single_output: SegmentAnythingH, + truck: Image.Image, + one_prompt: SAMPrompt, +) -> None: + # The refiners implementation of SAM differs from official + # implementation by a 6e-3 absolute diff (see test_predictor_single_output) + # This diff is related to 2 components : + # * image_encoder (see test_image_encoder) + # * point rescaling (facebook uses numpy while refiners uses torch) + # + # Current test is designed to workaround those 2 components + # * facebook image_embedding is used + # * the image is pre-resized by (1024, 1024) so there is no rescaling + # Then the test pass with torch.equal + + predictor = facebook_sam_h_predictor + size = (1024, 1024) + resized_truck = truck.resize(size) + predictor.set_image(np.array(resized_truck)) + + _, _, facebook_low_res_masks = predictor.predict( # type: ignore + **one_prompt.facebook_predict_kwargs(), # type: ignore + multimask_output=False, + ) + + facebook_image_embedding = ImageEmbedding(features=predictor.features, original_image_size=size) + + _, _, low_res_masks = sam_h_single_output.predict(facebook_image_embedding, **one_prompt.__dict__) + + assert torch.equal( + low_res_masks[0, 0, ...], + torch.as_tensor(facebook_low_res_masks[0], device=sam_h_single_output.device), + ) + + def test_mask_encoder( facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt ) -> None: diff --git a/tests/foundationals/segment_anything/utils.py b/tests/foundationals/segment_anything/utils.py index b397359..e98fe6d 100644 --- a/tests/foundationals/segment_anything/utils.py +++ b/tests/foundationals/segment_anything/utils.py @@ -40,6 +40,7 @@ class FacebookSAM(nn.Module): class FacebookSAMPredictor: model: FacebookSAM + features: Tensor def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ...