From 00f494efe26cb84768b8143919a4f7547239343b Mon Sep 17 00:00:00 2001 From: hugojarkoff Date: Fri, 5 Jan 2024 18:45:03 +0100 Subject: [PATCH] SegmentAnything: add dense mask prompt support --- .../conversion/convert_segment_anything.py | 29 ++++- .../foundationals/segment_anything/model.py | 13 +-- .../segment_anything/test_sam.py | 106 +++++++++++++++++- tests/foundationals/segment_anything/utils.py | 15 ++- 4 files changed, 143 insertions(+), 20 deletions(-) diff --git a/scripts/conversion/convert_segment_anything.py b/scripts/conversion/convert_segment_anything.py index 9057cbb..14ba2ef 100644 --- a/scripts/conversion/convert_segment_anything.py +++ b/scripts/conversion/convert_segment_anything.py @@ -37,13 +37,36 @@ class Args(argparse.Namespace): def convert_mask_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]: + manual_seed(seed=0) + refiners_mask_encoder = MaskEncoder() + + converter = ModelConverter( + source_model=prompt_encoder.mask_downscaling, + target_model=refiners_mask_encoder, + custom_layer_mapping=custom_layers, # type: ignore + ) + + x = torch.randn(1, 256, 256) + mapping = converter.map_state_dicts(source_args=(x,)) + assert mapping + + source_state_dict = prompt_encoder.mask_downscaling.state_dict() + target_state_dict = refiners_mask_encoder.state_dict() + + # Mapping handled manually (see below) because nn.Parameter is a special case + del target_state_dict["no_mask_embedding"] + + converted_source = converter._convert_state_dict( # pyright: ignore[reportPrivateUsage] + source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping + ) + state_dict: dict[str, Tensor] = { "no_mask_embedding": nn.Parameter(data=prompt_encoder.no_mask_embed.weight.clone()), # type: ignore } - refiners_mask_encoder = MaskEncoder() - # TODO: handle other weights - refiners_mask_encoder.load_state_dict(state_dict=state_dict, strict=False) + state_dict.update(converted_source) + + refiners_mask_encoder.load_state_dict(state_dict=state_dict) return state_dict diff --git a/src/refiners/foundationals/segment_anything/model.py b/src/refiners/foundationals/segment_anything/model.py index a7da415..905c4b6 100644 --- a/src/refiners/foundationals/segment_anything/model.py +++ b/src/refiners/foundationals/segment_anything/model.py @@ -3,11 +3,12 @@ from typing import Sequence import numpy as np import torch +from jaxtyping import Float from PIL import Image from torch import Tensor, device as Device, dtype as DType import refiners.fluxion.layers as fl -from refiners.fluxion.utils import image_to_tensor, interpolate, no_grad, normalize, pad +from refiners.fluxion.utils import interpolate, no_grad, normalize, pad from refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder @@ -55,7 +56,7 @@ class SegmentAnything(fl.Module): foreground_points: Sequence[tuple[float, float]] | None = None, background_points: Sequence[tuple[float, float]] | None = None, box_points: Sequence[Sequence[tuple[float, float]]] | None = None, - masks: Sequence[Image.Image] | None = None, + low_res_mask: Float[Tensor, "1 1 256 256"] | None = None, binarize: bool = True, ) -> tuple[Tensor, Tensor, Tensor]: if isinstance(input, ImageEmbedding): @@ -74,15 +75,13 @@ class SegmentAnything(fl.Module): ) self.point_encoder.set_type_mask(type_mask=type_mask) - if masks is not None: - mask_tensor = torch.stack( - tensors=[image_to_tensor(image=mask, device=self.device, dtype=self.dtype) for mask in masks] - ) - mask_embedding = self.mask_encoder(mask_tensor) + if low_res_mask is not None: + mask_embedding = self.mask_encoder(low_res_mask) else: mask_embedding = self.mask_encoder.get_no_mask_dense_embedding( image_embedding_size=self.image_encoder.image_embedding_size ) + point_embedding = self.point_encoder( self.normalize(coordinates, target_size=target_size, original_size=original_size) ) diff --git a/tests/foundationals/segment_anything/test_sam.py b/tests/foundationals/segment_anything/test_sam.py index de4147e..1e56685 100644 --- a/tests/foundationals/segment_anything/test_sam.py +++ b/tests/foundationals/segment_anything/test_sam.py @@ -23,7 +23,7 @@ from refiners.foundationals.segment_anything.image_encoder import FusedSelfAtten from refiners.foundationals.segment_anything.model import SegmentAnythingH from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer -# See predictor_example.ipynb official notebook (note: mask_input is not yet properly supported) +# See predictor_example.ipynb official notebook PROMPTS: list[SAMPrompt] = [ SAMPrompt(foreground_points=((500, 375),)), SAMPrompt(background_points=((500, 375),)), @@ -41,7 +41,9 @@ def prompt(request: pytest.FixtureRequest) -> SAMPrompt: @pytest.fixture def one_prompt() -> SAMPrompt: - return PROMPTS[0] + # Using the third prompt of the PROMPTS list in order to strictly do the same test as the official notebook in the + # test_predictor_dense_mask test. + return PROMPTS[2] @pytest.fixture(scope="module") @@ -83,8 +85,7 @@ def facebook_sam_h_predictor(facebook_sam_h: FacebookSAM) -> FacebookSAMPredicto @pytest.fixture(scope="module") def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH: sam_h = SegmentAnythingH(device=test_device) - # TODO: make strict=True when the MasKEncoder conversion is done - sam_h.load_from_safetensors(tensors_path=sam_h_weights, strict=False) + sam_h.load_from_safetensors(tensors_path=sam_h_weights) return sam_h @@ -164,7 +165,14 @@ def test_point_encoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH, pro **prompt.facebook_prompt_encoder_kwargs(device=facebook_sam_h.device) ) - coordinates, type_mask = refiners_prompt_encoder.points_to_tensor(**prompt.__dict__) + prompt_dict = prompt.__dict__ + # Skip mask prompt, if any, since the point encoder only consumes points and boxes + # TODO: split `SAMPrompt` and introduce a dedicated one for dense prompts + prompt_dict.pop("low_res_mask", None) + + assert prompt_dict is not None, "`test_point_encoder` cannot be called with just a `low_res_mask`" + + coordinates, type_mask = refiners_prompt_encoder.points_to_tensor(**prompt_dict) # Shift to center of pixel + normalize in [0, 1] (see `_embed_points` in segment-anything official repo) coordinates[:, :, 0] = (coordinates[:, :, 0] + 0.5) / 1024.0 coordinates[:, :, 1] = (coordinates[:, :, 1] + 0.5) / 1024.0 @@ -319,3 +327,91 @@ def test_predictor_image_embedding(sam_h: SegmentAnythingH, truck: Image.Image, assert torch.equal(masks, masks_ref) assert torch.equal(scores_ref, scores) + + +def test_predictor_dense_mask( + facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt +) -> None: + """ + NOTE : Binarizing intermediate masks isn't necessary, as per SamPredictor.predict_torch docstring: + > mask_input (np.ndarray): A low resolution mask input to the model, typically + > coming from a previous prediction iteration. Has form Bx1xHxW, where + > for SAM, H=W=256. Masks returned by a previous iteration of the + > predict method do not need further transformation. + """ + predictor = facebook_sam_h_predictor + predictor.set_image(np.array(truck)) + facebook_masks, facebook_scores, facebook_logits = predictor.predict( + **one_prompt.facebook_predict_kwargs(), # type: ignore + multimask_output=True, + ) + + assert len(facebook_masks) == 3 + + facebook_mask_input = facebook_logits[np.argmax(facebook_scores)] # shape: HxW + + # Using the same mask coordinates inputs as the official notebook + facebook_prompt = SAMPrompt( + foreground_points=((500, 375),), background_points=((1125, 625),), low_res_mask=facebook_mask_input[None, ...] + ) + facebook_dense_masks, _, _ = predictor.predict(**facebook_prompt.facebook_predict_kwargs(), multimask_output=True) # type: ignore + + assert len(facebook_dense_masks) == 3 + + masks, scores, logits = sam_h.predict(truck, **one_prompt.__dict__) + masks = masks.squeeze(0) + scores = scores.squeeze(0) + + assert len(masks) == 3 + + mask_input = logits[:, scores.max(dim=0).indices, ...] # shape: 1xHxW + + assert np.allclose( + mask_input.cpu().numpy(), facebook_mask_input, atol=1e-1 + ) # Lower doesn't pass, but it's close enough for logits + + refiners_prompt = SAMPrompt( + foreground_points=((500, 375),), background_points=((1125, 625),), low_res_mask=mask_input.unsqueeze(0) + ) + dense_masks, _, _ = sam_h.predict(truck, **refiners_prompt.__dict__) + dense_masks = dense_masks.squeeze(0) + + assert len(dense_masks) == 3 + + for i in range(3): + dense_mask_prediction = dense_masks[i].cpu() + facebook_dense_mask = torch.as_tensor(facebook_dense_masks[i]) + assert dense_mask_prediction.shape == facebook_dense_mask.shape + assert isclose(intersection_over_union(dense_mask_prediction, facebook_dense_mask), 1.0, rel_tol=5e-05) + + +def test_mask_encoder( + facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt +) -> None: + predictor = facebook_sam_h_predictor + predictor.set_image(np.array(truck)) + _, facebook_scores, facebook_logits = predictor.predict( + **one_prompt.facebook_predict_kwargs(), # type: ignore + multimask_output=True, + ) + facebook_mask_input = facebook_logits[np.argmax(facebook_scores)] + facebook_mask_input = ( + torch.from_numpy(facebook_mask_input) # type: ignore + .to(device=predictor.model.device) + .unsqueeze(0) + .unsqueeze(0) # shape: 1x1xHxW + ) + + _, fb_dense_embeddings = predictor.model.prompt_encoder( + points=None, + boxes=None, + masks=facebook_mask_input, + ) + + _, scores, logits = sam_h.predict(truck, **one_prompt.__dict__) + scores = scores.squeeze(0) + mask_input = logits[:, scores.max(dim=0).indices, ...].unsqueeze(0) # shape: 1x1xHxW + dense_embeddings = sam_h.mask_encoder(mask_input) + + assert facebook_mask_input.shape == mask_input.shape + assert torch.allclose(dense_embeddings, fb_dense_embeddings, atol=1e-4, rtol=1e-4) diff --git a/tests/foundationals/segment_anything/utils.py b/tests/foundationals/segment_anything/utils.py index ef73e36..fa18e88 100644 --- a/tests/foundationals/segment_anything/utils.py +++ b/tests/foundationals/segment_anything/utils.py @@ -63,8 +63,7 @@ class SAMPrompt: foreground_points: Sequence[tuple[float, float]] | None = None background_points: Sequence[tuple[float, float]] | None = None box_points: Sequence[Sequence[tuple[float, float]]] | None = None - # TODO: support masks - # masks: Sequence[Image.Image] | None = None + low_res_mask: Tensor | None = None def facebook_predict_kwargs(self) -> dict[str, NDArray]: prompt: dict[str, NDArray] = {} @@ -85,13 +84,18 @@ class SAMPrompt: prompt["box"] = np.array([coord for batch in self.box_points for xy in batch for coord in xy]).reshape( len(self.box_points), 4 ) + if self.low_res_mask is not None: + prompt["mask_input"] = np.array(self.low_res_mask) return prompt - def facebook_prompt_encoder_kwargs(self, device: torch.device | None = None): + def facebook_prompt_encoder_kwargs( + self, device: torch.device | None = None + ) -> dict[str, Tensor | tuple[Tensor, Tensor | None] | None]: prompt = self.facebook_predict_kwargs() coords: Tensor | None = None labels: Tensor | None = None boxes: Tensor | None = None + masks: Tensor | None = None if "point_coords" in prompt: coords = torch.as_tensor(prompt["point_coords"], dtype=torch.float, device=device).unsqueeze(0) if "point_labels" in prompt: @@ -99,8 +103,9 @@ class SAMPrompt: if "box" in prompt: boxes = torch.as_tensor(prompt["box"], dtype=torch.float, device=device).unsqueeze(0) points = (coords, labels) if coords is not None else None - # TODO: support masks - return {"points": points, "boxes": boxes, "masks": None} + if "mask_input" in prompt: + masks = torch.as_tensor(prompt["mask_input"], dtype=torch.float, device=device).unsqueeze(0) + return {"points": points, "boxes": boxes, "masks": masks} def intersection_over_union(