mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
SAM e2e test tolerance explained
This commit is contained in:
parent
364e196874
commit
2763db960e
|
@ -21,7 +21,7 @@ from refiners.fluxion import manual_seed
|
||||||
from refiners.fluxion.model_converter import ModelConverter
|
from refiners.fluxion.model_converter import ModelConverter
|
||||||
from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
|
from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
|
||||||
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention, RelativePositionAttention
|
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention, RelativePositionAttention
|
||||||
from refiners.foundationals.segment_anything.model import SegmentAnythingH
|
from refiners.foundationals.segment_anything.model import ImageEmbedding, SegmentAnythingH
|
||||||
from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer
|
from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer
|
||||||
|
|
||||||
# See predictor_example.ipynb official notebook
|
# See predictor_example.ipynb official notebook
|
||||||
|
@ -409,7 +409,7 @@ def test_predictor_single_output(
|
||||||
assert torch.allclose(
|
assert torch.allclose(
|
||||||
low_res_masks[0, 0, ...],
|
low_res_masks[0, 0, ...],
|
||||||
torch.as_tensor(facebook_low_res_masks[0], device=sam_h_single_output.device),
|
torch.as_tensor(facebook_low_res_masks[0], device=sam_h_single_output.device),
|
||||||
atol=6e-3, # TODO: This diff on logits is high, and requires deeper investigation
|
atol=6e-3, # see test_predictor_resized_single_output for more explanation
|
||||||
)
|
)
|
||||||
assert isclose(scores[0].item(), facebook_scores[0].item(), abs_tol=1e-05)
|
assert isclose(scores[0].item(), facebook_scores[0].item(), abs_tol=1e-05)
|
||||||
|
|
||||||
|
@ -418,6 +418,43 @@ def test_predictor_single_output(
|
||||||
assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05)
|
assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05)
|
||||||
|
|
||||||
|
|
||||||
|
def test_predictor_resized_single_output(
|
||||||
|
facebook_sam_h_predictor: FacebookSAMPredictor,
|
||||||
|
sam_h_single_output: SegmentAnythingH,
|
||||||
|
truck: Image.Image,
|
||||||
|
one_prompt: SAMPrompt,
|
||||||
|
) -> None:
|
||||||
|
# The refiners implementation of SAM differs from official
|
||||||
|
# implementation by a 6e-3 absolute diff (see test_predictor_single_output)
|
||||||
|
# This diff is related to 2 components :
|
||||||
|
# * image_encoder (see test_image_encoder)
|
||||||
|
# * point rescaling (facebook uses numpy while refiners uses torch)
|
||||||
|
#
|
||||||
|
# Current test is designed to workaround those 2 components
|
||||||
|
# * facebook image_embedding is used
|
||||||
|
# * the image is pre-resized by (1024, 1024) so there is no rescaling
|
||||||
|
# Then the test pass with torch.equal
|
||||||
|
|
||||||
|
predictor = facebook_sam_h_predictor
|
||||||
|
size = (1024, 1024)
|
||||||
|
resized_truck = truck.resize(size)
|
||||||
|
predictor.set_image(np.array(resized_truck))
|
||||||
|
|
||||||
|
_, _, facebook_low_res_masks = predictor.predict( # type: ignore
|
||||||
|
**one_prompt.facebook_predict_kwargs(), # type: ignore
|
||||||
|
multimask_output=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
facebook_image_embedding = ImageEmbedding(features=predictor.features, original_image_size=size)
|
||||||
|
|
||||||
|
_, _, low_res_masks = sam_h_single_output.predict(facebook_image_embedding, **one_prompt.__dict__)
|
||||||
|
|
||||||
|
assert torch.equal(
|
||||||
|
low_res_masks[0, 0, ...],
|
||||||
|
torch.as_tensor(facebook_low_res_masks[0], device=sam_h_single_output.device),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_mask_encoder(
|
def test_mask_encoder(
|
||||||
facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt
|
facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
|
@ -40,6 +40,7 @@ class FacebookSAM(nn.Module):
|
||||||
|
|
||||||
class FacebookSAMPredictor:
|
class FacebookSAMPredictor:
|
||||||
model: FacebookSAM
|
model: FacebookSAM
|
||||||
|
features: Tensor
|
||||||
|
|
||||||
def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ...
|
def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ...
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue