From fdeeb254353d3bcb41645bd00280211f37dcd84f Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Tue, 19 Mar 2024 16:08:54 +0000 Subject: [PATCH] Add multimask_output flag to SAM --- .../segment_anything/mask_decoder.py | 48 ++++++++++++------- .../foundationals/segment_anything/model.py | 10 +++- .../segment_anything/test_sam.py | 35 ++++++++++++++ 3 files changed, 74 insertions(+), 19 deletions(-) diff --git a/src/refiners/foundationals/segment_anything/mask_decoder.py b/src/refiners/foundationals/segment_anything/mask_decoder.py index 0f9c477..0827dd0 100644 --- a/src/refiners/foundationals/segment_anything/mask_decoder.py +++ b/src/refiners/foundationals/segment_anything/mask_decoder.py @@ -10,10 +10,6 @@ from refiners.foundationals.segment_anything.transformer import ( class EmbeddingsAggregator(fl.ContextModule): - def __init__(self, num_output_mask: int = 3) -> None: - super().__init__() - self.num_mask_tokens = num_output_mask - def forward(self, iou_mask_tokens: Tensor) -> Tensor: mask_decoder = self.ensure_parent mask_decoder_context = mask_decoder.use_context(context_name="mask_decoder") @@ -48,7 +44,7 @@ class Hypernetworks(fl.Concatenate): self, embedding_dim: int = 256, num_layers: int = 3, - num_mask_tokens: int = 3, + num_mask_tokens: int = 4, device: Device | str | None = None, dtype: DType | None = None, ) -> None: @@ -70,7 +66,7 @@ class Hypernetworks(fl.Concatenate): dtype=dtype, ), ) - for i in range(num_mask_tokens + 1) + for i in range(num_mask_tokens) ], dim=1, ) @@ -138,6 +134,7 @@ class MaskPrediction(fl.Chain): self, embedding_dim: int, num_mask_tokens: int, + multimask_output: bool, num_layers: int = 3, device: Device | str | None = None, dtype: DType | None = None, @@ -145,6 +142,10 @@ class MaskPrediction(fl.Chain): self.embedding_dim = embedding_dim self.num_mask_tokens = num_mask_tokens self.num_layers = num_layers + self.multimask_output = multimask_output + + start_mask, num_masks = (1, num_mask_tokens - 1) if multimask_output else (0, 1) + super().__init__( fl.Matmul( input=Hypernetworks( @@ -156,8 +157,8 @@ class MaskPrediction(fl.Chain): ), other=DenseEmbeddingUpscaling(embedding_dim=embedding_dim, device=device, dtype=dtype), ), - fl.Slicing(dim=1, start=1), - fl.Reshape(num_mask_tokens, embedding_dim, embedding_dim), + fl.Slicing(dim=1, start=start_mask, end=start_mask + num_masks), + fl.Reshape(num_masks, embedding_dim, embedding_dim), ) @@ -167,47 +168,53 @@ class IOUPrediction(fl.Chain): embedding_dim: int, num_layers: int, num_mask_tokens: int, + multimask_output: bool, device: Device | str | None = None, dtype: DType | None = None, ) -> None: self.embedding_dim = embedding_dim self.num_layers = num_layers + self.multimask_output = multimask_output + super().__init__( fl.Slicing(dim=1, start=0, end=1), fl.Squeeze(dim=0), fl.MultiLinear( input_dim=embedding_dim, - output_dim=num_mask_tokens + 1, + output_dim=num_mask_tokens, inner_dim=embedding_dim, num_layers=num_layers, device=device, dtype=dtype, ), - fl.Slicing(dim=-1, start=1), + fl.Slicing(dim=-1, start=1) if multimask_output else fl.Slicing(dim=-1, start=0, end=1), ) class MaskDecoder(fl.Chain): def __init__( self, + multimask_output: bool = True, embedding_dim: int = 256, feed_forward_dim: int = 2048, num_layers: int = 2, - num_output_mask: int = 3, + num_multimask_outputs: int = 3, device: Device | str | None = None, dtype: DType | None = None, ) -> None: super().__init__() + self.multimask_output = multimask_output self.embedding_dim = embedding_dim - self.num_mask_tokens = num_output_mask self.feed_forward_dim = feed_forward_dim self.num_layers = num_layers + self.num_multimask_outputs = num_multimask_outputs + + # The 1 additional token is for single-output mask prediction + num_mask_tokens = self.num_multimask_outputs + 1 super().__init__( - IOUMaskEncoder( - embedding_dim=embedding_dim, num_mask_tokens=num_output_mask + 1, device=device, dtype=dtype - ), - EmbeddingsAggregator(num_output_mask=num_output_mask), + IOUMaskEncoder(embedding_dim=embedding_dim, num_mask_tokens=num_mask_tokens, device=device, dtype=dtype), + EmbeddingsAggregator(), Transformer( *( TwoWayTransformerLayer( @@ -225,12 +232,17 @@ class MaskDecoder(fl.Chain): ), fl.Parallel( MaskPrediction( - embedding_dim=embedding_dim, num_mask_tokens=num_output_mask, device=device, dtype=dtype + embedding_dim=embedding_dim, + num_mask_tokens=num_mask_tokens, + multimask_output=multimask_output, + device=device, + dtype=dtype, ), IOUPrediction( embedding_dim=embedding_dim, num_layers=3, - num_mask_tokens=num_output_mask, + num_mask_tokens=num_mask_tokens, + multimask_output=multimask_output, device=device, dtype=dtype, ), diff --git a/src/refiners/foundationals/segment_anything/model.py b/src/refiners/foundationals/segment_anything/model.py index 655c3c2..4e3b03d 100644 --- a/src/refiners/foundationals/segment_anything/model.py +++ b/src/refiners/foundationals/segment_anything/model.py @@ -233,6 +233,7 @@ class SegmentAnythingH(SegmentAnything): point_encoder: PointEncoder | None = None, mask_encoder: MaskEncoder | None = None, mask_decoder: MaskDecoder | None = None, + multimask_output: bool | None = None, device: Device | str = "cpu", dtype: DType = torch.float32, ) -> None: @@ -243,13 +244,20 @@ class SegmentAnythingH(SegmentAnything): point_encoder: The point encoder to use. mask_encoder: The mask encoder to use. mask_decoder: The mask decoder to use. + multimask_output: Whether to use multimask output. device: The PyTorch device to use. dtype: The PyTorch data type to use. """ image_encoder = image_encoder or SAMViTH() point_encoder = point_encoder or PointEncoder() mask_encoder = mask_encoder or MaskEncoder() - mask_decoder = mask_decoder or MaskDecoder() + + if mask_decoder: + assert ( + mask_decoder.multimask_output == multimask_output + ), f"mask_decoder.multimask_output {mask_decoder.multimask_output} should match multimask_output (${multimask_output})" + else: + mask_decoder = MaskDecoder(multimask_output) if multimask_output is not None else MaskDecoder() super().__init__( image_encoder=image_encoder, diff --git a/tests/foundationals/segment_anything/test_sam.py b/tests/foundationals/segment_anything/test_sam.py index e8afa41..31e0eec 100644 --- a/tests/foundationals/segment_anything/test_sam.py +++ b/tests/foundationals/segment_anything/test_sam.py @@ -90,6 +90,13 @@ def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH: return sam_h +@pytest.fixture(scope="module") +def sam_h_single_output(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH: + sam_h = SegmentAnythingH(multimask_output=False, device=test_device) + sam_h.load_from_safetensors(tensors_path=sam_h_weights) + return sam_h + + @pytest.fixture(scope="module") def ref_path(test_sam_path: Path) -> Path: return test_sam_path / "test_sam_ref" @@ -391,6 +398,34 @@ def test_predictor_dense_mask( assert isclose(intersection_over_union(dense_mask_prediction, facebook_dense_mask), 1.0, rel_tol=5e-05) +def test_predictor_single_output( + facebook_sam_h_predictor: FacebookSAMPredictor, + sam_h_single_output: SegmentAnythingH, + truck: Image.Image, + one_prompt: SAMPrompt, +) -> None: + predictor = facebook_sam_h_predictor + predictor.set_image(np.array(truck)) + + facebook_masks, facebook_scores, _ = predictor.predict( # type: ignore + **one_prompt.facebook_predict_kwargs(), # type: ignore + multimask_output=False, + ) + + assert len(facebook_masks) == 1 + + masks, scores, _ = sam_h_single_output.predict(truck, **one_prompt.__dict__) + masks = masks.squeeze(0) + scores = scores.squeeze(0) + + assert len(masks) == 1 + + mask_prediction = masks[0].cpu() + facebook_mask = torch.as_tensor(facebook_masks[0]) + assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05) + assert isclose(scores[0].item(), facebook_scores[0].item(), rel_tol=1e-05) + + def test_mask_encoder( facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt ) -> None: