From c6b5eb24a179d48e4542d94684a70c5ef3142ab1 Mon Sep 17 00:00:00 2001 From: hugojarkoff Date: Thu, 21 Mar 2024 09:20:04 +0000 Subject: [PATCH] Add logits comparison for base SAM in single mask output prediction mode --- tests/foundationals/segment_anything/test_sam.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/foundationals/segment_anything/test_sam.py b/tests/foundationals/segment_anything/test_sam.py index 31e0eec..e40eab5 100644 --- a/tests/foundationals/segment_anything/test_sam.py +++ b/tests/foundationals/segment_anything/test_sam.py @@ -407,23 +407,29 @@ def test_predictor_single_output( predictor = facebook_sam_h_predictor predictor.set_image(np.array(truck)) - facebook_masks, facebook_scores, _ = predictor.predict( # type: ignore + facebook_masks, facebook_scores, facebook_low_res_masks = predictor.predict( # type: ignore **one_prompt.facebook_predict_kwargs(), # type: ignore multimask_output=False, ) assert len(facebook_masks) == 1 - masks, scores, _ = sam_h_single_output.predict(truck, **one_prompt.__dict__) + masks, scores, low_res_masks = sam_h_single_output.predict(truck, **one_prompt.__dict__) masks = masks.squeeze(0) scores = scores.squeeze(0) assert len(masks) == 1 + assert torch.allclose( + low_res_masks[0, 0, ...], + torch.as_tensor(facebook_low_res_masks[0], device=sam_h_single_output.device), + atol=6e-3, # TODO: This diff on logits is high, and requires deeper investigation + ) + assert isclose(scores[0].item(), facebook_scores[0].item(), abs_tol=1e-05) + mask_prediction = masks[0].cpu() facebook_mask = torch.as_tensor(facebook_masks[0]) assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05) - assert isclose(scores[0].item(), facebook_scores[0].item(), rel_tol=1e-05) def test_mask_encoder(