mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
Add logits comparison for base SAM in single mask output prediction mode
This commit is contained in:
parent
38c86f59f4
commit
c6b5eb24a1
|
@ -407,23 +407,29 @@ def test_predictor_single_output(
|
||||||
predictor = facebook_sam_h_predictor
|
predictor = facebook_sam_h_predictor
|
||||||
predictor.set_image(np.array(truck))
|
predictor.set_image(np.array(truck))
|
||||||
|
|
||||||
facebook_masks, facebook_scores, _ = predictor.predict( # type: ignore
|
facebook_masks, facebook_scores, facebook_low_res_masks = predictor.predict( # type: ignore
|
||||||
**one_prompt.facebook_predict_kwargs(), # type: ignore
|
**one_prompt.facebook_predict_kwargs(), # type: ignore
|
||||||
multimask_output=False,
|
multimask_output=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(facebook_masks) == 1
|
assert len(facebook_masks) == 1
|
||||||
|
|
||||||
masks, scores, _ = sam_h_single_output.predict(truck, **one_prompt.__dict__)
|
masks, scores, low_res_masks = sam_h_single_output.predict(truck, **one_prompt.__dict__)
|
||||||
masks = masks.squeeze(0)
|
masks = masks.squeeze(0)
|
||||||
scores = scores.squeeze(0)
|
scores = scores.squeeze(0)
|
||||||
|
|
||||||
assert len(masks) == 1
|
assert len(masks) == 1
|
||||||
|
|
||||||
|
assert torch.allclose(
|
||||||
|
low_res_masks[0, 0, ...],
|
||||||
|
torch.as_tensor(facebook_low_res_masks[0], device=sam_h_single_output.device),
|
||||||
|
atol=6e-3, # TODO: This diff on logits is high, and requires deeper investigation
|
||||||
|
)
|
||||||
|
assert isclose(scores[0].item(), facebook_scores[0].item(), abs_tol=1e-05)
|
||||||
|
|
||||||
mask_prediction = masks[0].cpu()
|
mask_prediction = masks[0].cpu()
|
||||||
facebook_mask = torch.as_tensor(facebook_masks[0])
|
facebook_mask = torch.as_tensor(facebook_masks[0])
|
||||||
assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05)
|
assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05)
|
||||||
assert isclose(scores[0].item(), facebook_scores[0].item(), rel_tol=1e-05)
|
|
||||||
|
|
||||||
|
|
||||||
def test_mask_encoder(
|
def test_mask_encoder(
|
||||||
|
|
Loading…
Reference in a new issue