mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
Add multimask_output flag to SAM
This commit is contained in:
parent
6a72943ff7
commit
68fe725767
|
@ -10,10 +10,6 @@ from refiners.foundationals.segment_anything.transformer import (
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsAggregator(fl.ContextModule):
|
class EmbeddingsAggregator(fl.ContextModule):
|
||||||
def __init__(self, num_output_mask: int = 3) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.num_mask_tokens = num_output_mask
|
|
||||||
|
|
||||||
def forward(self, iou_mask_tokens: Tensor) -> Tensor:
|
def forward(self, iou_mask_tokens: Tensor) -> Tensor:
|
||||||
mask_decoder = self.ensure_parent
|
mask_decoder = self.ensure_parent
|
||||||
mask_decoder_context = mask_decoder.use_context(context_name="mask_decoder")
|
mask_decoder_context = mask_decoder.use_context(context_name="mask_decoder")
|
||||||
|
@ -48,7 +44,7 @@ class Hypernetworks(fl.Concatenate):
|
||||||
self,
|
self,
|
||||||
embedding_dim: int = 256,
|
embedding_dim: int = 256,
|
||||||
num_layers: int = 3,
|
num_layers: int = 3,
|
||||||
num_mask_tokens: int = 3,
|
num_mask_tokens: int = 4,
|
||||||
device: Device | str | None = None,
|
device: Device | str | None = None,
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -70,7 +66,7 @@ class Hypernetworks(fl.Concatenate):
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
for i in range(num_mask_tokens + 1)
|
for i in range(num_mask_tokens)
|
||||||
],
|
],
|
||||||
dim=1,
|
dim=1,
|
||||||
)
|
)
|
||||||
|
@ -138,6 +134,7 @@ class MaskPrediction(fl.Chain):
|
||||||
self,
|
self,
|
||||||
embedding_dim: int,
|
embedding_dim: int,
|
||||||
num_mask_tokens: int,
|
num_mask_tokens: int,
|
||||||
|
multimask_output: bool,
|
||||||
num_layers: int = 3,
|
num_layers: int = 3,
|
||||||
device: Device | str | None = None,
|
device: Device | str | None = None,
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
|
@ -145,6 +142,10 @@ class MaskPrediction(fl.Chain):
|
||||||
self.embedding_dim = embedding_dim
|
self.embedding_dim = embedding_dim
|
||||||
self.num_mask_tokens = num_mask_tokens
|
self.num_mask_tokens = num_mask_tokens
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
|
self.multimask_output = multimask_output
|
||||||
|
|
||||||
|
start_mask, num_masks = (1, num_mask_tokens - 1) if multimask_output else (0, 1)
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Matmul(
|
fl.Matmul(
|
||||||
input=Hypernetworks(
|
input=Hypernetworks(
|
||||||
|
@ -156,8 +157,8 @@ class MaskPrediction(fl.Chain):
|
||||||
),
|
),
|
||||||
other=DenseEmbeddingUpscaling(embedding_dim=embedding_dim, device=device, dtype=dtype),
|
other=DenseEmbeddingUpscaling(embedding_dim=embedding_dim, device=device, dtype=dtype),
|
||||||
),
|
),
|
||||||
fl.Slicing(dim=1, start=1),
|
fl.Slicing(dim=1, start=start_mask, end=start_mask + num_masks),
|
||||||
fl.Reshape(num_mask_tokens, embedding_dim, embedding_dim),
|
fl.Reshape(num_masks, embedding_dim, embedding_dim),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -167,47 +168,53 @@ class IOUPrediction(fl.Chain):
|
||||||
embedding_dim: int,
|
embedding_dim: int,
|
||||||
num_layers: int,
|
num_layers: int,
|
||||||
num_mask_tokens: int,
|
num_mask_tokens: int,
|
||||||
|
multimask_output: bool,
|
||||||
device: Device | str | None = None,
|
device: Device | str | None = None,
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.embedding_dim = embedding_dim
|
self.embedding_dim = embedding_dim
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
|
self.multimask_output = multimask_output
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Slicing(dim=1, start=0, end=1),
|
fl.Slicing(dim=1, start=0, end=1),
|
||||||
fl.Squeeze(dim=0),
|
fl.Squeeze(dim=0),
|
||||||
fl.MultiLinear(
|
fl.MultiLinear(
|
||||||
input_dim=embedding_dim,
|
input_dim=embedding_dim,
|
||||||
output_dim=num_mask_tokens + 1,
|
output_dim=num_mask_tokens,
|
||||||
inner_dim=embedding_dim,
|
inner_dim=embedding_dim,
|
||||||
num_layers=num_layers,
|
num_layers=num_layers,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
),
|
),
|
||||||
fl.Slicing(dim=-1, start=1),
|
fl.Slicing(dim=-1, start=1) if multimask_output else fl.Slicing(dim=-1, start=0, end=1),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MaskDecoder(fl.Chain):
|
class MaskDecoder(fl.Chain):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
multimask_output: bool = True,
|
||||||
embedding_dim: int = 256,
|
embedding_dim: int = 256,
|
||||||
feed_forward_dim: int = 2048,
|
feed_forward_dim: int = 2048,
|
||||||
num_layers: int = 2,
|
num_layers: int = 2,
|
||||||
num_output_mask: int = 3,
|
num_multimask_outputs: int = 3,
|
||||||
device: Device | str | None = None,
|
device: Device | str | None = None,
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.multimask_output = multimask_output
|
||||||
self.embedding_dim = embedding_dim
|
self.embedding_dim = embedding_dim
|
||||||
self.num_mask_tokens = num_output_mask
|
|
||||||
self.feed_forward_dim = feed_forward_dim
|
self.feed_forward_dim = feed_forward_dim
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
|
self.num_multimask_outputs = num_multimask_outputs
|
||||||
|
|
||||||
|
# The 1 additional token is for single-output mask prediction
|
||||||
|
num_mask_tokens = self.num_multimask_outputs + 1
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
IOUMaskEncoder(
|
IOUMaskEncoder(embedding_dim=embedding_dim, num_mask_tokens=num_mask_tokens, device=device, dtype=dtype),
|
||||||
embedding_dim=embedding_dim, num_mask_tokens=num_output_mask + 1, device=device, dtype=dtype
|
EmbeddingsAggregator(),
|
||||||
),
|
|
||||||
EmbeddingsAggregator(num_output_mask=num_output_mask),
|
|
||||||
Transformer(
|
Transformer(
|
||||||
*(
|
*(
|
||||||
TwoWayTransformerLayer(
|
TwoWayTransformerLayer(
|
||||||
|
@ -225,12 +232,17 @@ class MaskDecoder(fl.Chain):
|
||||||
),
|
),
|
||||||
fl.Parallel(
|
fl.Parallel(
|
||||||
MaskPrediction(
|
MaskPrediction(
|
||||||
embedding_dim=embedding_dim, num_mask_tokens=num_output_mask, device=device, dtype=dtype
|
embedding_dim=embedding_dim,
|
||||||
|
num_mask_tokens=num_mask_tokens,
|
||||||
|
multimask_output=multimask_output,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
),
|
),
|
||||||
IOUPrediction(
|
IOUPrediction(
|
||||||
embedding_dim=embedding_dim,
|
embedding_dim=embedding_dim,
|
||||||
num_layers=3,
|
num_layers=3,
|
||||||
num_mask_tokens=num_output_mask,
|
num_mask_tokens=num_mask_tokens,
|
||||||
|
multimask_output=multimask_output,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
),
|
),
|
||||||
|
|
|
@ -233,6 +233,7 @@ class SegmentAnythingH(SegmentAnything):
|
||||||
point_encoder: PointEncoder | None = None,
|
point_encoder: PointEncoder | None = None,
|
||||||
mask_encoder: MaskEncoder | None = None,
|
mask_encoder: MaskEncoder | None = None,
|
||||||
mask_decoder: MaskDecoder | None = None,
|
mask_decoder: MaskDecoder | None = None,
|
||||||
|
multimask_output: bool | None = None,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: DType = torch.float32,
|
dtype: DType = torch.float32,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -243,13 +244,20 @@ class SegmentAnythingH(SegmentAnything):
|
||||||
point_encoder: The point encoder to use.
|
point_encoder: The point encoder to use.
|
||||||
mask_encoder: The mask encoder to use.
|
mask_encoder: The mask encoder to use.
|
||||||
mask_decoder: The mask decoder to use.
|
mask_decoder: The mask decoder to use.
|
||||||
|
multimask_output: Whether to use multimask output.
|
||||||
device: The PyTorch device to use.
|
device: The PyTorch device to use.
|
||||||
dtype: The PyTorch data type to use.
|
dtype: The PyTorch data type to use.
|
||||||
"""
|
"""
|
||||||
image_encoder = image_encoder or SAMViTH()
|
image_encoder = image_encoder or SAMViTH()
|
||||||
point_encoder = point_encoder or PointEncoder()
|
point_encoder = point_encoder or PointEncoder()
|
||||||
mask_encoder = mask_encoder or MaskEncoder()
|
mask_encoder = mask_encoder or MaskEncoder()
|
||||||
mask_decoder = mask_decoder or MaskDecoder()
|
|
||||||
|
if mask_decoder:
|
||||||
|
assert (
|
||||||
|
mask_decoder.multimask_output == multimask_output
|
||||||
|
), f"mask_decoder.multimask_output {mask_decoder.multimask_output} should match multimask_output (${multimask_output})"
|
||||||
|
else:
|
||||||
|
mask_decoder = MaskDecoder(multimask_output) if multimask_output is not None else MaskDecoder()
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
image_encoder=image_encoder,
|
image_encoder=image_encoder,
|
||||||
|
|
|
@ -90,6 +90,13 @@ def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
|
||||||
return sam_h
|
return sam_h
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def sam_h_single_output(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
|
||||||
|
sam_h = SegmentAnythingH(multimask_output=False, device=test_device)
|
||||||
|
sam_h.load_from_safetensors(tensors_path=sam_h_weights)
|
||||||
|
return sam_h
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def ref_path(test_sam_path: Path) -> Path:
|
def ref_path(test_sam_path: Path) -> Path:
|
||||||
return test_sam_path / "test_sam_ref"
|
return test_sam_path / "test_sam_ref"
|
||||||
|
@ -391,6 +398,34 @@ def test_predictor_dense_mask(
|
||||||
assert isclose(intersection_over_union(dense_mask_prediction, facebook_dense_mask), 1.0, rel_tol=5e-05)
|
assert isclose(intersection_over_union(dense_mask_prediction, facebook_dense_mask), 1.0, rel_tol=5e-05)
|
||||||
|
|
||||||
|
|
||||||
|
def test_predictor_single_output(
|
||||||
|
facebook_sam_h_predictor: FacebookSAMPredictor,
|
||||||
|
sam_h_single_output: SegmentAnythingH,
|
||||||
|
truck: Image.Image,
|
||||||
|
one_prompt: SAMPrompt,
|
||||||
|
) -> None:
|
||||||
|
predictor = facebook_sam_h_predictor
|
||||||
|
predictor.set_image(np.array(truck))
|
||||||
|
|
||||||
|
facebook_masks, facebook_scores, _ = predictor.predict( # type: ignore
|
||||||
|
**one_prompt.facebook_predict_kwargs(), # type: ignore
|
||||||
|
multimask_output=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(facebook_masks) == 1
|
||||||
|
|
||||||
|
masks, scores, _ = sam_h_single_output.predict(truck, **one_prompt.__dict__)
|
||||||
|
masks = masks.squeeze(0)
|
||||||
|
scores = scores.squeeze(0)
|
||||||
|
|
||||||
|
assert len(masks) == 1
|
||||||
|
|
||||||
|
mask_prediction = masks[0].cpu()
|
||||||
|
facebook_mask = torch.as_tensor(facebook_masks[0])
|
||||||
|
assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05)
|
||||||
|
assert isclose(scores[0].item(), facebook_scores[0].item(), rel_tol=1e-05)
|
||||||
|
|
||||||
|
|
||||||
def test_mask_encoder(
|
def test_mask_encoder(
|
||||||
facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt
|
facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
Loading…
Reference in a new issue