SAM e2e test tolerance explained

Pierre Colle 2024-03-22 14:22:37 +00:00 committed by Colle
parent 364e196874
commit 2763db960e
2 changed files with 40 additions and 2 deletions

Changed file 1 of 2:

@@ -21,7 +21,7 @@ from refiners.fluxion import manual_seed
 from refiners.fluxion.model_converter import ModelConverter
 from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
 from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention, RelativePositionAttention
-from refiners.foundationals.segment_anything.model import SegmentAnythingH
+from refiners.foundationals.segment_anything.model import ImageEmbedding, SegmentAnythingH
 from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer
 
 # See predictor_example.ipynb official notebook
@@ -409,7 +409,7 @@ def test_predictor_single_output(
     assert torch.allclose(
         low_res_masks[0, 0, ...],
         torch.as_tensor(facebook_low_res_masks[0], device=sam_h_single_output.device),
-        atol=6e-3,  # TODO: This diff on logits is high, and requires deeper investigation
+        atol=6e-3,  # see test_predictor_resized_single_output for more explanation
     )
 
     assert isclose(scores[0].item(), facebook_scores[0].item(), abs_tol=1e-05)
@@ -418,6 +418,43 @@ def test_predictor_single_output(
     assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05)
 
 
+def test_predictor_resized_single_output(
+    facebook_sam_h_predictor: FacebookSAMPredictor,
+    sam_h_single_output: SegmentAnythingH,
+    truck: Image.Image,
+    one_prompt: SAMPrompt,
+) -> None:
+    # The refiners implementation of SAM differs from the official implementation
+    # by a 6e-3 absolute diff on the logits (see test_predictor_single_output).
+    # This diff comes from 2 components:
+    #   * the image_encoder (see test_image_encoder)
+    #   * point rescaling (facebook uses numpy while refiners uses torch)
+    #
+    # This test is designed to work around those 2 components:
+    #   * the facebook image_embedding is used
+    #   * the image is pre-resized to (1024, 1024) so there is no rescaling
+    # The comparison then passes with torch.equal.
+    predictor = facebook_sam_h_predictor
+    size = (1024, 1024)
+    resized_truck = truck.resize(size)
+    predictor.set_image(np.array(resized_truck))
+
+    _, _, facebook_low_res_masks = predictor.predict(  # type: ignore
+        **one_prompt.facebook_predict_kwargs(),  # type: ignore
+        multimask_output=False,
+    )
+
+    facebook_image_embedding = ImageEmbedding(features=predictor.features, original_image_size=size)
+    _, _, low_res_masks = sam_h_single_output.predict(facebook_image_embedding, **one_prompt.__dict__)
+
+    assert torch.equal(
+        low_res_masks[0, 0, ...],
+        torch.as_tensor(facebook_low_res_masks[0], device=sam_h_single_output.device),
+    )
+
+
 def test_mask_encoder(
     facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt
 ) -> None:
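
The point-rescaling difference mentioned in the comment above is easy to miss, so here is a minimal sketch of the kind of discrepancy it refers to. The sizes, names, and printed values below are illustrative assumptions, not code from this commit: the official predictor rescales prompt coordinates with numpy (float64 by default), while a torch-based path typically works in float32, so the two results only agree up to float32 precision.

# Hypothetical sketch (not part of this commit): rescaling the same prompt point
# with numpy float64 vs. torch float32 gives values that agree only up to
# float32 precision; differences of this size can propagate into small logit diffs.
import numpy as np
import torch

original_size = (1200, 1800)  # (height, width) of the input image -- assumed value
target_long_side = 1024
scale = target_long_side / max(original_size)  # SAM resizes the longest side to 1024

point = (500.0, 375.0)  # an arbitrary prompt coordinate

numpy_coords = np.array(point, dtype=np.float64) * scale         # numpy-style rescaling
torch_coords = torch.tensor(point, dtype=torch.float32) * scale  # torch-style rescaling

print(numpy_coords)  # e.g. [284.44444444 213.33333333]
print(torch_coords)  # e.g. tensor([284.4445, 213.3333])
print(np.abs(numpy_coords - torch_coords.numpy()).max())  # tiny, but non-zero (~1e-5)

Pre-resizing the image to (1024, 1024), as the new test does, makes the scale exactly 1.0, so both paths produce identical coordinates and the comparison can use torch.equal.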

Changed file 2 of 2:

@@ -40,6 +40,7 @@ class FacebookSAM(nn.Module):
 
 class FacebookSAMPredictor:
     model: FacebookSAM
+    features: Tensor
 
     def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ...
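
For context, the `features: Tensor` attribute added to the predictor stub mirrors the official `SamPredictor`, which caches the image encoder output after `set_image` is called; that cached tensor is what the new test wraps into a refiners `ImageEmbedding`. A minimal sketch of that flow, assuming the official segment-anything API (the helper function below is hypothetical):

# Hypothetical sketch (not part of this commit): where `predictor.features` comes from.
import numpy as np
import torch
from PIL import Image

def official_image_embedding(predictor, image: Image.Image) -> torch.Tensor:
    """Run the official image encoder once and return the cached embedding tensor."""
    predictor.set_image(np.array(image))  # the encoder forward pass happens here
    return predictor.features             # cached embedding, expected shape (1, 256, 64, 64)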