Add multimask_output flag to SAM

This commit is contained in:
Pierre Colle 2024-03-19 16:08:54 +00:00 committed by Colle
parent 6a72943ff7
commit 68fe725767
3 changed files with 74 additions and 19 deletions

View file

@ -10,10 +10,6 @@ from refiners.foundationals.segment_anything.transformer import (
class EmbeddingsAggregator(fl.ContextModule): class EmbeddingsAggregator(fl.ContextModule):
def __init__(self, num_output_mask: int = 3) -> None:
super().__init__()
self.num_mask_tokens = num_output_mask
def forward(self, iou_mask_tokens: Tensor) -> Tensor: def forward(self, iou_mask_tokens: Tensor) -> Tensor:
mask_decoder = self.ensure_parent mask_decoder = self.ensure_parent
mask_decoder_context = mask_decoder.use_context(context_name="mask_decoder") mask_decoder_context = mask_decoder.use_context(context_name="mask_decoder")
@ -48,7 +44,7 @@ class Hypernetworks(fl.Concatenate):
self, self,
embedding_dim: int = 256, embedding_dim: int = 256,
num_layers: int = 3, num_layers: int = 3,
num_mask_tokens: int = 3, num_mask_tokens: int = 4,
device: Device | str | None = None, device: Device | str | None = None,
dtype: DType | None = None, dtype: DType | None = None,
) -> None: ) -> None:
@ -70,7 +66,7 @@ class Hypernetworks(fl.Concatenate):
dtype=dtype, dtype=dtype,
), ),
) )
for i in range(num_mask_tokens + 1) for i in range(num_mask_tokens)
], ],
dim=1, dim=1,
) )
@ -138,6 +134,7 @@ class MaskPrediction(fl.Chain):
self, self,
embedding_dim: int, embedding_dim: int,
num_mask_tokens: int, num_mask_tokens: int,
multimask_output: bool,
num_layers: int = 3, num_layers: int = 3,
device: Device | str | None = None, device: Device | str | None = None,
dtype: DType | None = None, dtype: DType | None = None,
@ -145,6 +142,10 @@ class MaskPrediction(fl.Chain):
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
self.num_mask_tokens = num_mask_tokens self.num_mask_tokens = num_mask_tokens
self.num_layers = num_layers self.num_layers = num_layers
self.multimask_output = multimask_output
start_mask, num_masks = (1, num_mask_tokens - 1) if multimask_output else (0, 1)
super().__init__( super().__init__(
fl.Matmul( fl.Matmul(
input=Hypernetworks( input=Hypernetworks(
@ -156,8 +157,8 @@ class MaskPrediction(fl.Chain):
), ),
other=DenseEmbeddingUpscaling(embedding_dim=embedding_dim, device=device, dtype=dtype), other=DenseEmbeddingUpscaling(embedding_dim=embedding_dim, device=device, dtype=dtype),
), ),
fl.Slicing(dim=1, start=1), fl.Slicing(dim=1, start=start_mask, end=start_mask + num_masks),
fl.Reshape(num_mask_tokens, embedding_dim, embedding_dim), fl.Reshape(num_masks, embedding_dim, embedding_dim),
) )
@ -167,47 +168,53 @@ class IOUPrediction(fl.Chain):
embedding_dim: int, embedding_dim: int,
num_layers: int, num_layers: int,
num_mask_tokens: int, num_mask_tokens: int,
multimask_output: bool,
device: Device | str | None = None, device: Device | str | None = None,
dtype: DType | None = None, dtype: DType | None = None,
) -> None: ) -> None:
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
self.num_layers = num_layers self.num_layers = num_layers
self.multimask_output = multimask_output
super().__init__( super().__init__(
fl.Slicing(dim=1, start=0, end=1), fl.Slicing(dim=1, start=0, end=1),
fl.Squeeze(dim=0), fl.Squeeze(dim=0),
fl.MultiLinear( fl.MultiLinear(
input_dim=embedding_dim, input_dim=embedding_dim,
output_dim=num_mask_tokens + 1, output_dim=num_mask_tokens,
inner_dim=embedding_dim, inner_dim=embedding_dim,
num_layers=num_layers, num_layers=num_layers,
device=device, device=device,
dtype=dtype, dtype=dtype,
), ),
fl.Slicing(dim=-1, start=1), fl.Slicing(dim=-1, start=1) if multimask_output else fl.Slicing(dim=-1, start=0, end=1),
) )
class MaskDecoder(fl.Chain): class MaskDecoder(fl.Chain):
def __init__( def __init__(
self, self,
multimask_output: bool = True,
embedding_dim: int = 256, embedding_dim: int = 256,
feed_forward_dim: int = 2048, feed_forward_dim: int = 2048,
num_layers: int = 2, num_layers: int = 2,
num_output_mask: int = 3, num_multimask_outputs: int = 3,
device: Device | str | None = None, device: Device | str | None = None,
dtype: DType | None = None, dtype: DType | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.multimask_output = multimask_output
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
self.num_mask_tokens = num_output_mask
self.feed_forward_dim = feed_forward_dim self.feed_forward_dim = feed_forward_dim
self.num_layers = num_layers self.num_layers = num_layers
self.num_multimask_outputs = num_multimask_outputs
# The 1 additional token is for single-output mask prediction
num_mask_tokens = self.num_multimask_outputs + 1
super().__init__( super().__init__(
IOUMaskEncoder( IOUMaskEncoder(embedding_dim=embedding_dim, num_mask_tokens=num_mask_tokens, device=device, dtype=dtype),
embedding_dim=embedding_dim, num_mask_tokens=num_output_mask + 1, device=device, dtype=dtype EmbeddingsAggregator(),
),
EmbeddingsAggregator(num_output_mask=num_output_mask),
Transformer( Transformer(
*( *(
TwoWayTransformerLayer( TwoWayTransformerLayer(
@ -225,12 +232,17 @@ class MaskDecoder(fl.Chain):
), ),
fl.Parallel( fl.Parallel(
MaskPrediction( MaskPrediction(
embedding_dim=embedding_dim, num_mask_tokens=num_output_mask, device=device, dtype=dtype embedding_dim=embedding_dim,
num_mask_tokens=num_mask_tokens,
multimask_output=multimask_output,
device=device,
dtype=dtype,
), ),
IOUPrediction( IOUPrediction(
embedding_dim=embedding_dim, embedding_dim=embedding_dim,
num_layers=3, num_layers=3,
num_mask_tokens=num_output_mask, num_mask_tokens=num_mask_tokens,
multimask_output=multimask_output,
device=device, device=device,
dtype=dtype, dtype=dtype,
), ),

View file

@ -233,6 +233,7 @@ class SegmentAnythingH(SegmentAnything):
point_encoder: PointEncoder | None = None, point_encoder: PointEncoder | None = None,
mask_encoder: MaskEncoder | None = None, mask_encoder: MaskEncoder | None = None,
mask_decoder: MaskDecoder | None = None, mask_decoder: MaskDecoder | None = None,
multimask_output: bool | None = None,
device: Device | str = "cpu", device: Device | str = "cpu",
dtype: DType = torch.float32, dtype: DType = torch.float32,
) -> None: ) -> None:
@ -243,13 +244,20 @@ class SegmentAnythingH(SegmentAnything):
point_encoder: The point encoder to use. point_encoder: The point encoder to use.
mask_encoder: The mask encoder to use. mask_encoder: The mask encoder to use.
mask_decoder: The mask decoder to use. mask_decoder: The mask decoder to use.
multimask_output: Whether to use multimask output.
device: The PyTorch device to use. device: The PyTorch device to use.
dtype: The PyTorch data type to use. dtype: The PyTorch data type to use.
""" """
image_encoder = image_encoder or SAMViTH() image_encoder = image_encoder or SAMViTH()
point_encoder = point_encoder or PointEncoder() point_encoder = point_encoder or PointEncoder()
mask_encoder = mask_encoder or MaskEncoder() mask_encoder = mask_encoder or MaskEncoder()
mask_decoder = mask_decoder or MaskDecoder()
# Either validate a caller-supplied decoder against the requested flag, or build one.
if mask_decoder:
    # multimask_output=None means the caller expressed no preference, so any
    # supplied decoder is acceptable; only an explicit True/False must match.
    assert multimask_output is None or mask_decoder.multimask_output == multimask_output, (
        f"mask_decoder.multimask_output {mask_decoder.multimask_output} "
        f"should match multimask_output ({multimask_output})"
    )
else:
    # No decoder supplied: forward the flag when given, otherwise keep MaskDecoder's own default.
    mask_decoder = MaskDecoder(multimask_output) if multimask_output is not None else MaskDecoder()
super().__init__( super().__init__(
image_encoder=image_encoder, image_encoder=image_encoder,

View file

@ -90,6 +90,13 @@ def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
return sam_h return sam_h
@pytest.fixture(scope="module")
def sam_h_single_output(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
    """Module-scoped SegmentAnythingH built with ``multimask_output=False``.

    The model is created on ``test_device`` and its weights are loaded from
    the safetensors file at ``sam_h_weights``.
    """
    single_output_model = SegmentAnythingH(multimask_output=False, device=test_device)
    single_output_model.load_from_safetensors(tensors_path=sam_h_weights)
    return single_output_model
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def ref_path(test_sam_path: Path) -> Path: def ref_path(test_sam_path: Path) -> Path:
return test_sam_path / "test_sam_ref" return test_sam_path / "test_sam_ref"
@ -391,6 +398,34 @@ def test_predictor_dense_mask(
assert isclose(intersection_over_union(dense_mask_prediction, facebook_dense_mask), 1.0, rel_tol=5e-05) assert isclose(intersection_over_union(dense_mask_prediction, facebook_dense_mask), 1.0, rel_tol=5e-05)
def test_predictor_single_output(
    facebook_sam_h_predictor: FacebookSAMPredictor,
    sam_h_single_output: SegmentAnythingH,
    truck: Image.Image,
    one_prompt: SAMPrompt,
) -> None:
    """Compare single-mask prediction with the Facebook reference predictor.

    Both models are run on the ``truck`` image with the same prompt. Each must
    return exactly one mask, and the mask IoU and the score must agree within
    tight relative tolerances.
    """
    # Reference run with multimask output disabled.
    facebook_sam_h_predictor.set_image(np.array(truck))
    reference_masks, reference_scores, _ = facebook_sam_h_predictor.predict(  # type: ignore
        **one_prompt.facebook_predict_kwargs(),  # type: ignore
        multimask_output=False,
    )
    assert len(reference_masks) == 1

    # Run under test: drop the leading dimension before inspecting the outputs.
    predicted_masks, predicted_scores, _ = sam_h_single_output.predict(truck, **one_prompt.__dict__)
    predicted_masks = predicted_masks.squeeze(0)
    predicted_scores = predicted_scores.squeeze(0)
    assert len(predicted_masks) == 1

    iou = intersection_over_union(predicted_masks[0].cpu(), torch.as_tensor(reference_masks[0]))
    assert isclose(iou, 1.0, rel_tol=5e-05)
    assert isclose(predicted_scores[0].item(), reference_scores[0].item(), rel_tol=1e-05)
def test_mask_encoder( def test_mask_encoder(
facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt
) -> None: ) -> None: