From cba83b0558e11574efd6788412c46f9d3626f991 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Fri, 22 Mar 2024 21:50:27 +0000 Subject: [PATCH] SAM init with mask_decoder after #325 --- .../foundationals/segment_anything/model.py | 4 ++-- tests/foundationals/segment_anything/test_sam.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/refiners/foundationals/segment_anything/model.py b/src/refiners/foundationals/segment_anything/model.py index c178bd0..7a50ddf 100644 --- a/src/refiners/foundationals/segment_anything/model.py +++ b/src/refiners/foundationals/segment_anything/model.py @@ -268,8 +268,8 @@ class SegmentAnythingH(SegmentAnything): if mask_decoder: assert ( - mask_decoder.multimask_output == multimask_output - ), f"mask_decoder.multimask_output {mask_decoder.multimask_output} should match multimask_output (${multimask_output})" + multimask_output is None or mask_decoder.multimask_output == multimask_output + ), f"mask_decoder.multimask_output {mask_decoder.multimask_output} should match multimask_output ({multimask_output})" else: mask_decoder = MaskDecoder(multimask_output) if multimask_output is not None else MaskDecoder() diff --git a/tests/foundationals/segment_anything/test_sam.py b/tests/foundationals/segment_anything/test_sam.py index f521f14..df37e30 100644 --- a/tests/foundationals/segment_anything/test_sam.py +++ b/tests/foundationals/segment_anything/test_sam.py @@ -21,6 +21,7 @@ from refiners.fluxion import manual_seed from refiners.fluxion.model_converter import ModelConverter from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention, RelativePositionAttention +from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder from refiners.foundationals.segment_anything.model import ImageEmbedding, SegmentAnythingH from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer @@ -124,6 +125,19 @@ def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None: assert torch.equal(input=y_1, other=y_2) +def test_mask_decoder_arg() -> None: + mask_decoder_default = MaskDecoder() + sam_h = SegmentAnythingH(mask_decoder=mask_decoder_default) + + assert sam_h.mask_decoder == mask_decoder_default + + +def test_multimask_output_error() -> None: + mask_decoder_multimask_output = MaskDecoder(multimask_output=True) + with pytest.raises(AssertionError, match="multimask_output"): + SegmentAnythingH(mask_decoder=mask_decoder_multimask_output, multimask_output=False) + + @no_grad() def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, truck: Image.Image) -> None: image_tensor = image_to_tensor(image=truck.resize(size=(1024, 1024)), device=facebook_sam_h.device)