Add logits comparison for base SAM in single mask output prediction mode

This commit is contained in:
hugojarkoff 2024-03-21 09:20:04 +00:00 committed by hugojarkoff
parent 38c86f59f4
commit c6b5eb24a1

View file

@ -407,23 +407,29 @@ def test_predictor_single_output(
predictor = facebook_sam_h_predictor
predictor.set_image(np.array(truck))
facebook_masks, facebook_scores, _ = predictor.predict( # type: ignore
facebook_masks, facebook_scores, facebook_low_res_masks = predictor.predict( # type: ignore
**one_prompt.facebook_predict_kwargs(), # type: ignore
multimask_output=False,
)
assert len(facebook_masks) == 1
masks, scores, _ = sam_h_single_output.predict(truck, **one_prompt.__dict__)
masks, scores, low_res_masks = sam_h_single_output.predict(truck, **one_prompt.__dict__)
masks = masks.squeeze(0)
scores = scores.squeeze(0)
assert len(masks) == 1
assert torch.allclose(
low_res_masks[0, 0, ...],
torch.as_tensor(facebook_low_res_masks[0], device=sam_h_single_output.device),
atol=6e-3, # TODO: This diff on logits is high, and requires deeper investigation
)
assert isclose(scores[0].item(), facebook_scores[0].item(), abs_tol=1e-05)
mask_prediction = masks[0].cpu()
facebook_mask = torch.as_tensor(facebook_masks[0])
assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05)
assert isclose(scores[0].item(), facebook_scores[0].item(), rel_tol=1e-05)
def test_mask_encoder(