SegmentAnything: add dense mask prompt support

This commit is contained in:
hugojarkoff 2024-01-05 18:45:03 +01:00 committed by hugojarkoff
parent 20c229903f
commit 00f494efe2
4 changed files with 143 additions and 20 deletions

View file

@ -37,13 +37,36 @@ class Args(argparse.Namespace):
def convert_mask_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]: def convert_mask_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]:
manual_seed(seed=0)
refiners_mask_encoder = MaskEncoder()
converter = ModelConverter(
source_model=prompt_encoder.mask_downscaling,
target_model=refiners_mask_encoder,
custom_layer_mapping=custom_layers, # type: ignore
)
x = torch.randn(1, 256, 256)
mapping = converter.map_state_dicts(source_args=(x,))
assert mapping
source_state_dict = prompt_encoder.mask_downscaling.state_dict()
target_state_dict = refiners_mask_encoder.state_dict()
# Mapping handled manually (see below) because nn.Parameter is a special case
del target_state_dict["no_mask_embedding"]
converted_source = converter._convert_state_dict( # pyright: ignore[reportPrivateUsage]
source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
)
state_dict: dict[str, Tensor] = { state_dict: dict[str, Tensor] = {
"no_mask_embedding": nn.Parameter(data=prompt_encoder.no_mask_embed.weight.clone()), # type: ignore "no_mask_embedding": nn.Parameter(data=prompt_encoder.no_mask_embed.weight.clone()), # type: ignore
} }
refiners_mask_encoder = MaskEncoder() state_dict.update(converted_source)
# TODO: handle other weights
refiners_mask_encoder.load_state_dict(state_dict=state_dict, strict=False) refiners_mask_encoder.load_state_dict(state_dict=state_dict)
return state_dict return state_dict

View file

@ -3,11 +3,12 @@ from typing import Sequence
import numpy as np import numpy as np
import torch import torch
from jaxtyping import Float
from PIL import Image from PIL import Image
from torch import Tensor, device as Device, dtype as DType from torch import Tensor, device as Device, dtype as DType
import refiners.fluxion.layers as fl import refiners.fluxion.layers as fl
from refiners.fluxion.utils import image_to_tensor, interpolate, no_grad, normalize, pad from refiners.fluxion.utils import interpolate, no_grad, normalize, pad
from refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH from refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
@ -55,7 +56,7 @@ class SegmentAnything(fl.Module):
foreground_points: Sequence[tuple[float, float]] | None = None, foreground_points: Sequence[tuple[float, float]] | None = None,
background_points: Sequence[tuple[float, float]] | None = None, background_points: Sequence[tuple[float, float]] | None = None,
box_points: Sequence[Sequence[tuple[float, float]]] | None = None, box_points: Sequence[Sequence[tuple[float, float]]] | None = None,
masks: Sequence[Image.Image] | None = None, low_res_mask: Float[Tensor, "1 1 256 256"] | None = None,
binarize: bool = True, binarize: bool = True,
) -> tuple[Tensor, Tensor, Tensor]: ) -> tuple[Tensor, Tensor, Tensor]:
if isinstance(input, ImageEmbedding): if isinstance(input, ImageEmbedding):
@ -74,15 +75,13 @@ class SegmentAnything(fl.Module):
) )
self.point_encoder.set_type_mask(type_mask=type_mask) self.point_encoder.set_type_mask(type_mask=type_mask)
if masks is not None: if low_res_mask is not None:
mask_tensor = torch.stack( mask_embedding = self.mask_encoder(low_res_mask)
tensors=[image_to_tensor(image=mask, device=self.device, dtype=self.dtype) for mask in masks]
)
mask_embedding = self.mask_encoder(mask_tensor)
else: else:
mask_embedding = self.mask_encoder.get_no_mask_dense_embedding( mask_embedding = self.mask_encoder.get_no_mask_dense_embedding(
image_embedding_size=self.image_encoder.image_embedding_size image_embedding_size=self.image_encoder.image_embedding_size
) )
point_embedding = self.point_encoder( point_embedding = self.point_encoder(
self.normalize(coordinates, target_size=target_size, original_size=original_size) self.normalize(coordinates, target_size=target_size, original_size=original_size)
) )

View file

@ -23,7 +23,7 @@ from refiners.foundationals.segment_anything.image_encoder import FusedSelfAtten
from refiners.foundationals.segment_anything.model import SegmentAnythingH from refiners.foundationals.segment_anything.model import SegmentAnythingH
from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer
# See predictor_example.ipynb official notebook (note: mask_input is not yet properly supported) # See predictor_example.ipynb official notebook
PROMPTS: list[SAMPrompt] = [ PROMPTS: list[SAMPrompt] = [
SAMPrompt(foreground_points=((500, 375),)), SAMPrompt(foreground_points=((500, 375),)),
SAMPrompt(background_points=((500, 375),)), SAMPrompt(background_points=((500, 375),)),
@ -41,7 +41,9 @@ def prompt(request: pytest.FixtureRequest) -> SAMPrompt:
@pytest.fixture @pytest.fixture
def one_prompt() -> SAMPrompt: def one_prompt() -> SAMPrompt:
return PROMPTS[0] # Using the third prompt of the PROMPTS list in order to strictly do the same test as the official notebook in the
# test_predictor_dense_mask test.
return PROMPTS[2]
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
@ -83,8 +85,7 @@ def facebook_sam_h_predictor(facebook_sam_h: FacebookSAM) -> FacebookSAMPredicto
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH: def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
sam_h = SegmentAnythingH(device=test_device) sam_h = SegmentAnythingH(device=test_device)
# TODO: make strict=True when the MasKEncoder conversion is done sam_h.load_from_safetensors(tensors_path=sam_h_weights)
sam_h.load_from_safetensors(tensors_path=sam_h_weights, strict=False)
return sam_h return sam_h
@ -164,7 +165,14 @@ def test_point_encoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH, pro
**prompt.facebook_prompt_encoder_kwargs(device=facebook_sam_h.device) **prompt.facebook_prompt_encoder_kwargs(device=facebook_sam_h.device)
) )
coordinates, type_mask = refiners_prompt_encoder.points_to_tensor(**prompt.__dict__) prompt_dict = prompt.__dict__
# Skip mask prompt, if any, since the point encoder only consumes points and boxes
# TODO: split `SAMPrompt` and introduce a dedicated one for dense prompts
prompt_dict.pop("low_res_mask", None)
assert prompt_dict is not None, "`test_point_encoder` cannot be called with just a `low_res_mask`"
coordinates, type_mask = refiners_prompt_encoder.points_to_tensor(**prompt_dict)
# Shift to center of pixel + normalize in [0, 1] (see `_embed_points` in segment-anything official repo) # Shift to center of pixel + normalize in [0, 1] (see `_embed_points` in segment-anything official repo)
coordinates[:, :, 0] = (coordinates[:, :, 0] + 0.5) / 1024.0 coordinates[:, :, 0] = (coordinates[:, :, 0] + 0.5) / 1024.0
coordinates[:, :, 1] = (coordinates[:, :, 1] + 0.5) / 1024.0 coordinates[:, :, 1] = (coordinates[:, :, 1] + 0.5) / 1024.0
@ -319,3 +327,91 @@ def test_predictor_image_embedding(sam_h: SegmentAnythingH, truck: Image.Image,
assert torch.equal(masks, masks_ref) assert torch.equal(masks, masks_ref)
assert torch.equal(scores_ref, scores) assert torch.equal(scores_ref, scores)
def test_predictor_dense_mask(
facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt
) -> None:
"""
NOTE : Binarizing intermediate masks isn't necessary, as per SamPredictor.predict_torch docstring:
> mask_input (np.ndarray): A low resolution mask input to the model, typically
> coming from a previous prediction iteration. Has form Bx1xHxW, where
> for SAM, H=W=256. Masks returned by a previous iteration of the
> predict method do not need further transformation.
"""
predictor = facebook_sam_h_predictor
predictor.set_image(np.array(truck))
facebook_masks, facebook_scores, facebook_logits = predictor.predict(
**one_prompt.facebook_predict_kwargs(), # type: ignore
multimask_output=True,
)
assert len(facebook_masks) == 3
facebook_mask_input = facebook_logits[np.argmax(facebook_scores)] # shape: HxW
# Using the same mask coordinates inputs as the official notebook
facebook_prompt = SAMPrompt(
foreground_points=((500, 375),), background_points=((1125, 625),), low_res_mask=facebook_mask_input[None, ...]
)
facebook_dense_masks, _, _ = predictor.predict(**facebook_prompt.facebook_predict_kwargs(), multimask_output=True) # type: ignore
assert len(facebook_dense_masks) == 3
masks, scores, logits = sam_h.predict(truck, **one_prompt.__dict__)
masks = masks.squeeze(0)
scores = scores.squeeze(0)
assert len(masks) == 3
mask_input = logits[:, scores.max(dim=0).indices, ...] # shape: 1xHxW
assert np.allclose(
mask_input.cpu().numpy(), facebook_mask_input, atol=1e-1
) # Lower doesn't pass, but it's close enough for logits
refiners_prompt = SAMPrompt(
foreground_points=((500, 375),), background_points=((1125, 625),), low_res_mask=mask_input.unsqueeze(0)
)
dense_masks, _, _ = sam_h.predict(truck, **refiners_prompt.__dict__)
dense_masks = dense_masks.squeeze(0)
assert len(dense_masks) == 3
for i in range(3):
dense_mask_prediction = dense_masks[i].cpu()
facebook_dense_mask = torch.as_tensor(facebook_dense_masks[i])
assert dense_mask_prediction.shape == facebook_dense_mask.shape
assert isclose(intersection_over_union(dense_mask_prediction, facebook_dense_mask), 1.0, rel_tol=5e-05)
def test_mask_encoder(
facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt
) -> None:
predictor = facebook_sam_h_predictor
predictor.set_image(np.array(truck))
_, facebook_scores, facebook_logits = predictor.predict(
**one_prompt.facebook_predict_kwargs(), # type: ignore
multimask_output=True,
)
facebook_mask_input = facebook_logits[np.argmax(facebook_scores)]
facebook_mask_input = (
torch.from_numpy(facebook_mask_input) # type: ignore
.to(device=predictor.model.device)
.unsqueeze(0)
.unsqueeze(0) # shape: 1x1xHxW
)
_, fb_dense_embeddings = predictor.model.prompt_encoder(
points=None,
boxes=None,
masks=facebook_mask_input,
)
_, scores, logits = sam_h.predict(truck, **one_prompt.__dict__)
scores = scores.squeeze(0)
mask_input = logits[:, scores.max(dim=0).indices, ...].unsqueeze(0) # shape: 1x1xHxW
dense_embeddings = sam_h.mask_encoder(mask_input)
assert facebook_mask_input.shape == mask_input.shape
assert torch.allclose(dense_embeddings, fb_dense_embeddings, atol=1e-4, rtol=1e-4)

View file

@ -63,8 +63,7 @@ class SAMPrompt:
foreground_points: Sequence[tuple[float, float]] | None = None foreground_points: Sequence[tuple[float, float]] | None = None
background_points: Sequence[tuple[float, float]] | None = None background_points: Sequence[tuple[float, float]] | None = None
box_points: Sequence[Sequence[tuple[float, float]]] | None = None box_points: Sequence[Sequence[tuple[float, float]]] | None = None
# TODO: support masks low_res_mask: Tensor | None = None
# masks: Sequence[Image.Image] | None = None
def facebook_predict_kwargs(self) -> dict[str, NDArray]: def facebook_predict_kwargs(self) -> dict[str, NDArray]:
prompt: dict[str, NDArray] = {} prompt: dict[str, NDArray] = {}
@ -85,13 +84,18 @@ class SAMPrompt:
prompt["box"] = np.array([coord for batch in self.box_points for xy in batch for coord in xy]).reshape( prompt["box"] = np.array([coord for batch in self.box_points for xy in batch for coord in xy]).reshape(
len(self.box_points), 4 len(self.box_points), 4
) )
if self.low_res_mask is not None:
prompt["mask_input"] = np.array(self.low_res_mask)
return prompt return prompt
def facebook_prompt_encoder_kwargs(self, device: torch.device | None = None): def facebook_prompt_encoder_kwargs(
self, device: torch.device | None = None
) -> dict[str, Tensor | tuple[Tensor, Tensor | None] | None]:
prompt = self.facebook_predict_kwargs() prompt = self.facebook_predict_kwargs()
coords: Tensor | None = None coords: Tensor | None = None
labels: Tensor | None = None labels: Tensor | None = None
boxes: Tensor | None = None boxes: Tensor | None = None
masks: Tensor | None = None
if "point_coords" in prompt: if "point_coords" in prompt:
coords = torch.as_tensor(prompt["point_coords"], dtype=torch.float, device=device).unsqueeze(0) coords = torch.as_tensor(prompt["point_coords"], dtype=torch.float, device=device).unsqueeze(0)
if "point_labels" in prompt: if "point_labels" in prompt:
@ -99,8 +103,9 @@ class SAMPrompt:
if "box" in prompt: if "box" in prompt:
boxes = torch.as_tensor(prompt["box"], dtype=torch.float, device=device).unsqueeze(0) boxes = torch.as_tensor(prompt["box"], dtype=torch.float, device=device).unsqueeze(0)
points = (coords, labels) if coords is not None else None points = (coords, labels) if coords is not None else None
# TODO: support masks if "mask_input" in prompt:
return {"points": points, "boxes": boxes, "masks": None} masks = torch.as_tensor(prompt["mask_input"], dtype=torch.float, device=device).unsqueeze(0)
return {"points": points, "boxes": boxes, "masks": masks}
def intersection_over_union( def intersection_over_union(