mirror of https://github.com/finegrain-ai/refiners.git
SAM e2e test tolerance explained
parent 364e196874
commit 2763db960e
@@ -21,7 +21,7 @@ from refiners.fluxion import manual_seed
 from refiners.fluxion.model_converter import ModelConverter
 from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
 from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention, RelativePositionAttention
-from refiners.foundationals.segment_anything.model import SegmentAnythingH
+from refiners.foundationals.segment_anything.model import ImageEmbedding, SegmentAnythingH
 from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer
 
 # See predictor_example.ipynb official notebook
@@ -409,7 +409,7 @@ def test_predictor_single_output(
     assert torch.allclose(
         low_res_masks[0, 0, ...],
         torch.as_tensor(facebook_low_res_masks[0], device=sam_h_single_output.device),
-        atol=6e-3,  # TODO: This diff on logits is high, and requires deeper investigation
+        atol=6e-3,  # see test_predictor_resized_single_output for more explanation
     )
     assert isclose(scores[0].item(), facebook_scores[0].item(), abs_tol=1e-05)
 
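
A note on what that atol bound measures: torch.allclose requires every element-wise gap to stay below atol plus a small relative term, so 6e-3 is essentially the worst-case absolute difference between the two logit tensors. Below is a minimal sketch of how one might inspect that gap directly; the tensor names are placeholders, not fixtures from this test suite.

import torch

def max_abs_diff(ours: torch.Tensor, theirs: torch.Tensor) -> float:
    # Largest element-wise gap between two mask logit tensors; this is the
    # quantity that an atol=6e-3 bound is effectively guarding against.
    return (ours - theirs).abs().max().item()

# Usage with random stand-ins for the two implementations' outputs:
reference = torch.randn(256, 256)
candidate = reference + 5e-3 * torch.rand(256, 256)  # perturbation strictly below 6e-3
print(max_abs_diff(reference, candidate))  # prints a value smaller than 6e-3
assert torch.allclose(reference, candidate, atol=6e-3)
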
@@ -418,6 +418,43 @@ def test_predictor_single_output(
     assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05)
 
 
+def test_predictor_resized_single_output(
+    facebook_sam_h_predictor: FacebookSAMPredictor,
+    sam_h_single_output: SegmentAnythingH,
+    truck: Image.Image,
+    one_prompt: SAMPrompt,
+) -> None:
+    # The refiners implementation of SAM differs from the official
+    # implementation by a 6e-3 absolute diff (see test_predictor_single_output).
+    # This diff is related to 2 components:
+    #     * image_encoder (see test_image_encoder)
+    #     * point rescaling (facebook uses numpy while refiners uses torch)
+    #
+    # The current test is designed to work around those 2 components:
+    #     * the facebook image_embedding is used
+    #     * the image is pre-resized to (1024, 1024) so there is no rescaling
+    # The test then passes with torch.equal.
+
+    predictor = facebook_sam_h_predictor
+    size = (1024, 1024)
+    resized_truck = truck.resize(size)
+    predictor.set_image(np.array(resized_truck))
+
+    _, _, facebook_low_res_masks = predictor.predict(  # type: ignore
+        **one_prompt.facebook_predict_kwargs(),  # type: ignore
+        multimask_output=False,
+    )
+
+    facebook_image_embedding = ImageEmbedding(features=predictor.features, original_image_size=size)
+
+    _, _, low_res_masks = sam_h_single_output.predict(facebook_image_embedding, **one_prompt.__dict__)
+
+    assert torch.equal(
+        low_res_masks[0, 0, ...],
+        torch.as_tensor(facebook_low_res_masks[0], device=sam_h_single_output.device),
+    )
+
+
 def test_mask_encoder(
     facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt
 ) -> None:
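
The comment in the new test attributes part of the residual difference to point rescaling being done with numpy in the official predictor and with torch in refiners. The sketch below only illustrates that idea and is not the code of either library: the coordinate transform is reduced to a bare scale factor, and once the image is pre-resized to the model's (1024, 1024) input size that factor is exactly 1.0, so both code paths return identical coordinates and this source of drift disappears.

import numpy as np
import torch

def rescale_points_np(coords: np.ndarray, old_size: tuple[int, int], new_size: tuple[int, int]) -> np.ndarray:
    # Simplified numpy stand-in for a point-rescaling transform.
    scale = np.array([new_size[0] / old_size[0], new_size[1] / old_size[1]], dtype=np.float32)
    return coords.astype(np.float32) * scale

def rescale_points_torch(coords: torch.Tensor, old_size: tuple[int, int], new_size: tuple[int, int]) -> torch.Tensor:
    # Simplified torch stand-in for the same transform.
    scale = torch.tensor([new_size[0] / old_size[0], new_size[1] / old_size[1]], dtype=torch.float32)
    return coords.to(torch.float32) * scale

# With the image already at (1024, 1024), old_size == new_size, the scale is
# exactly 1.0, and float32 rounding cannot make the two results diverge.
points = np.array([[500.0, 375.0]], dtype=np.float32)
out_np = rescale_points_np(points, (1024, 1024), (1024, 1024))
out_torch = rescale_points_torch(torch.from_numpy(points), (1024, 1024), (1024, 1024))
assert np.array_equal(out_np, out_torch.numpy())
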
@@ -40,6 +40,7 @@ class FacebookSAM(nn.Module):
 
 class FacebookSAMPredictor:
     model: FacebookSAM
+    features: Tensor
 
     def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ...
 