mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
Add logits comparison for base SAM in single mask output prediction mode
This commit is contained in:
parent
38c86f59f4
commit
c6b5eb24a1
|
@ -407,23 +407,29 @@ def test_predictor_single_output(
|
|||
predictor = facebook_sam_h_predictor
|
||||
predictor.set_image(np.array(truck))
|
||||
|
||||
facebook_masks, facebook_scores, _ = predictor.predict( # type: ignore
|
||||
facebook_masks, facebook_scores, facebook_low_res_masks = predictor.predict( # type: ignore
|
||||
**one_prompt.facebook_predict_kwargs(), # type: ignore
|
||||
multimask_output=False,
|
||||
)
|
||||
|
||||
assert len(facebook_masks) == 1
|
||||
|
||||
masks, scores, _ = sam_h_single_output.predict(truck, **one_prompt.__dict__)
|
||||
masks, scores, low_res_masks = sam_h_single_output.predict(truck, **one_prompt.__dict__)
|
||||
masks = masks.squeeze(0)
|
||||
scores = scores.squeeze(0)
|
||||
|
||||
assert len(masks) == 1
|
||||
|
||||
assert torch.allclose(
|
||||
low_res_masks[0, 0, ...],
|
||||
torch.as_tensor(facebook_low_res_masks[0], device=sam_h_single_output.device),
|
||||
atol=6e-3, # TODO: This diff on logits is high, and requires deeper investigation
|
||||
)
|
||||
assert isclose(scores[0].item(), facebook_scores[0].item(), abs_tol=1e-05)
|
||||
|
||||
mask_prediction = masks[0].cpu()
|
||||
facebook_mask = torch.as_tensor(facebook_masks[0])
|
||||
assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05)
|
||||
assert isclose(scores[0].item(), facebook_scores[0].item(), rel_tol=1e-05)
|
||||
|
||||
|
||||
def test_mask_encoder(
|
||||
|
|
Loading…
Reference in a new issue