mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
Add multimask_output flag to SAM
This commit is contained in:
parent
6a72943ff7
commit
68fe725767
|
@ -10,10 +10,6 @@ from refiners.foundationals.segment_anything.transformer import (
|
|||
|
||||
|
||||
class EmbeddingsAggregator(fl.ContextModule):
|
||||
def __init__(self, num_output_mask: int = 3) -> None:
|
||||
super().__init__()
|
||||
self.num_mask_tokens = num_output_mask
|
||||
|
||||
def forward(self, iou_mask_tokens: Tensor) -> Tensor:
|
||||
mask_decoder = self.ensure_parent
|
||||
mask_decoder_context = mask_decoder.use_context(context_name="mask_decoder")
|
||||
|
@ -48,7 +44,7 @@ class Hypernetworks(fl.Concatenate):
|
|||
self,
|
||||
embedding_dim: int = 256,
|
||||
num_layers: int = 3,
|
||||
num_mask_tokens: int = 3,
|
||||
num_mask_tokens: int = 4,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
|
@ -70,7 +66,7 @@ class Hypernetworks(fl.Concatenate):
|
|||
dtype=dtype,
|
||||
),
|
||||
)
|
||||
for i in range(num_mask_tokens + 1)
|
||||
for i in range(num_mask_tokens)
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
@ -138,6 +134,7 @@ class MaskPrediction(fl.Chain):
|
|||
self,
|
||||
embedding_dim: int,
|
||||
num_mask_tokens: int,
|
||||
multimask_output: bool,
|
||||
num_layers: int = 3,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
|
@ -145,6 +142,10 @@ class MaskPrediction(fl.Chain):
|
|||
self.embedding_dim = embedding_dim
|
||||
self.num_mask_tokens = num_mask_tokens
|
||||
self.num_layers = num_layers
|
||||
self.multimask_output = multimask_output
|
||||
|
||||
start_mask, num_masks = (1, num_mask_tokens - 1) if multimask_output else (0, 1)
|
||||
|
||||
super().__init__(
|
||||
fl.Matmul(
|
||||
input=Hypernetworks(
|
||||
|
@ -156,8 +157,8 @@ class MaskPrediction(fl.Chain):
|
|||
),
|
||||
other=DenseEmbeddingUpscaling(embedding_dim=embedding_dim, device=device, dtype=dtype),
|
||||
),
|
||||
fl.Slicing(dim=1, start=1),
|
||||
fl.Reshape(num_mask_tokens, embedding_dim, embedding_dim),
|
||||
fl.Slicing(dim=1, start=start_mask, end=start_mask + num_masks),
|
||||
fl.Reshape(num_masks, embedding_dim, embedding_dim),
|
||||
)
|
||||
|
||||
|
||||
|
@ -167,47 +168,53 @@ class IOUPrediction(fl.Chain):
|
|||
embedding_dim: int,
|
||||
num_layers: int,
|
||||
num_mask_tokens: int,
|
||||
multimask_output: bool,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_layers = num_layers
|
||||
self.multimask_output = multimask_output
|
||||
|
||||
super().__init__(
|
||||
fl.Slicing(dim=1, start=0, end=1),
|
||||
fl.Squeeze(dim=0),
|
||||
fl.MultiLinear(
|
||||
input_dim=embedding_dim,
|
||||
output_dim=num_mask_tokens + 1,
|
||||
output_dim=num_mask_tokens,
|
||||
inner_dim=embedding_dim,
|
||||
num_layers=num_layers,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
fl.Slicing(dim=-1, start=1),
|
||||
fl.Slicing(dim=-1, start=1) if multimask_output else fl.Slicing(dim=-1, start=0, end=1),
|
||||
)
|
||||
|
||||
|
||||
class MaskDecoder(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
multimask_output: bool = True,
|
||||
embedding_dim: int = 256,
|
||||
feed_forward_dim: int = 2048,
|
||||
num_layers: int = 2,
|
||||
num_output_mask: int = 3,
|
||||
num_multimask_outputs: int = 3,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.multimask_output = multimask_output
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_mask_tokens = num_output_mask
|
||||
self.feed_forward_dim = feed_forward_dim
|
||||
self.num_layers = num_layers
|
||||
self.num_multimask_outputs = num_multimask_outputs
|
||||
|
||||
# The 1 additional token is for single-output mask prediction
|
||||
num_mask_tokens = self.num_multimask_outputs + 1
|
||||
|
||||
super().__init__(
|
||||
IOUMaskEncoder(
|
||||
embedding_dim=embedding_dim, num_mask_tokens=num_output_mask + 1, device=device, dtype=dtype
|
||||
),
|
||||
EmbeddingsAggregator(num_output_mask=num_output_mask),
|
||||
IOUMaskEncoder(embedding_dim=embedding_dim, num_mask_tokens=num_mask_tokens, device=device, dtype=dtype),
|
||||
EmbeddingsAggregator(),
|
||||
Transformer(
|
||||
*(
|
||||
TwoWayTransformerLayer(
|
||||
|
@ -225,12 +232,17 @@ class MaskDecoder(fl.Chain):
|
|||
),
|
||||
fl.Parallel(
|
||||
MaskPrediction(
|
||||
embedding_dim=embedding_dim, num_mask_tokens=num_output_mask, device=device, dtype=dtype
|
||||
embedding_dim=embedding_dim,
|
||||
num_mask_tokens=num_mask_tokens,
|
||||
multimask_output=multimask_output,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
IOUPrediction(
|
||||
embedding_dim=embedding_dim,
|
||||
num_layers=3,
|
||||
num_mask_tokens=num_output_mask,
|
||||
num_mask_tokens=num_mask_tokens,
|
||||
multimask_output=multimask_output,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
|
|
|
@ -233,6 +233,7 @@ class SegmentAnythingH(SegmentAnything):
|
|||
point_encoder: PointEncoder | None = None,
|
||||
mask_encoder: MaskEncoder | None = None,
|
||||
mask_decoder: MaskDecoder | None = None,
|
||||
multimask_output: bool | None = None,
|
||||
device: Device | str = "cpu",
|
||||
dtype: DType = torch.float32,
|
||||
) -> None:
|
||||
|
@ -243,13 +244,20 @@ class SegmentAnythingH(SegmentAnything):
|
|||
point_encoder: The point encoder to use.
|
||||
mask_encoder: The mask encoder to use.
|
||||
mask_decoder: The mask decoder to use.
|
||||
multimask_output: Whether to use multimask output.
|
||||
device: The PyTorch device to use.
|
||||
dtype: The PyTorch data type to use.
|
||||
"""
|
||||
image_encoder = image_encoder or SAMViTH()
|
||||
point_encoder = point_encoder or PointEncoder()
|
||||
mask_encoder = mask_encoder or MaskEncoder()
|
||||
mask_decoder = mask_decoder or MaskDecoder()
|
||||
|
||||
if mask_decoder:
|
||||
assert (
|
||||
mask_decoder.multimask_output == multimask_output
|
||||
), f"mask_decoder.multimask_output {mask_decoder.multimask_output} should match multimask_output (${multimask_output})"
|
||||
else:
|
||||
mask_decoder = MaskDecoder(multimask_output) if multimask_output is not None else MaskDecoder()
|
||||
|
||||
super().__init__(
|
||||
image_encoder=image_encoder,
|
||||
|
|
|
@ -90,6 +90,13 @@ def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
|
|||
return sam_h
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def sam_h_single_output(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
|
||||
sam_h = SegmentAnythingH(multimask_output=False, device=test_device)
|
||||
sam_h.load_from_safetensors(tensors_path=sam_h_weights)
|
||||
return sam_h
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ref_path(test_sam_path: Path) -> Path:
|
||||
return test_sam_path / "test_sam_ref"
|
||||
|
@ -391,6 +398,34 @@ def test_predictor_dense_mask(
|
|||
assert isclose(intersection_over_union(dense_mask_prediction, facebook_dense_mask), 1.0, rel_tol=5e-05)
|
||||
|
||||
|
||||
def test_predictor_single_output(
|
||||
facebook_sam_h_predictor: FacebookSAMPredictor,
|
||||
sam_h_single_output: SegmentAnythingH,
|
||||
truck: Image.Image,
|
||||
one_prompt: SAMPrompt,
|
||||
) -> None:
|
||||
predictor = facebook_sam_h_predictor
|
||||
predictor.set_image(np.array(truck))
|
||||
|
||||
facebook_masks, facebook_scores, _ = predictor.predict( # type: ignore
|
||||
**one_prompt.facebook_predict_kwargs(), # type: ignore
|
||||
multimask_output=False,
|
||||
)
|
||||
|
||||
assert len(facebook_masks) == 1
|
||||
|
||||
masks, scores, _ = sam_h_single_output.predict(truck, **one_prompt.__dict__)
|
||||
masks = masks.squeeze(0)
|
||||
scores = scores.squeeze(0)
|
||||
|
||||
assert len(masks) == 1
|
||||
|
||||
mask_prediction = masks[0].cpu()
|
||||
facebook_mask = torch.as_tensor(facebook_masks[0])
|
||||
assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05)
|
||||
assert isclose(scores[0].item(), facebook_scores[0].item(), rel_tol=1e-05)
|
||||
|
||||
|
||||
def test_mask_encoder(
|
||||
facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt
|
||||
) -> None:
|
||||
|
|
Loading…
Reference in a new issue