mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
SegmentAnything: add dense mask prompt support
This commit is contained in:
parent
20c229903f
commit
00f494efe2
|
@ -37,13 +37,36 @@ class Args(argparse.Namespace):
|
|||
|
||||
|
||||
def convert_mask_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]:
|
||||
manual_seed(seed=0)
|
||||
refiners_mask_encoder = MaskEncoder()
|
||||
|
||||
converter = ModelConverter(
|
||||
source_model=prompt_encoder.mask_downscaling,
|
||||
target_model=refiners_mask_encoder,
|
||||
custom_layer_mapping=custom_layers, # type: ignore
|
||||
)
|
||||
|
||||
x = torch.randn(1, 256, 256)
|
||||
mapping = converter.map_state_dicts(source_args=(x,))
|
||||
assert mapping
|
||||
|
||||
source_state_dict = prompt_encoder.mask_downscaling.state_dict()
|
||||
target_state_dict = refiners_mask_encoder.state_dict()
|
||||
|
||||
# Mapping handled manually (see below) because nn.Parameter is a special case
|
||||
del target_state_dict["no_mask_embedding"]
|
||||
|
||||
converted_source = converter._convert_state_dict( # pyright: ignore[reportPrivateUsage]
|
||||
source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
|
||||
)
|
||||
|
||||
state_dict: dict[str, Tensor] = {
|
||||
"no_mask_embedding": nn.Parameter(data=prompt_encoder.no_mask_embed.weight.clone()), # type: ignore
|
||||
}
|
||||
|
||||
refiners_mask_encoder = MaskEncoder()
|
||||
# TODO: handle other weights
|
||||
refiners_mask_encoder.load_state_dict(state_dict=state_dict, strict=False)
|
||||
state_dict.update(converted_source)
|
||||
|
||||
refiners_mask_encoder.load_state_dict(state_dict=state_dict)
|
||||
|
||||
return state_dict
|
||||
|
||||
|
|
|
@ -3,11 +3,12 @@ from typing import Sequence
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from jaxtyping import Float
|
||||
from PIL import Image
|
||||
from torch import Tensor, device as Device, dtype as DType
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.utils import image_to_tensor, interpolate, no_grad, normalize, pad
|
||||
from refiners.fluxion.utils import interpolate, no_grad, normalize, pad
|
||||
from refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH
|
||||
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
|
||||
from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
|
||||
|
@ -55,7 +56,7 @@ class SegmentAnything(fl.Module):
|
|||
foreground_points: Sequence[tuple[float, float]] | None = None,
|
||||
background_points: Sequence[tuple[float, float]] | None = None,
|
||||
box_points: Sequence[Sequence[tuple[float, float]]] | None = None,
|
||||
masks: Sequence[Image.Image] | None = None,
|
||||
low_res_mask: Float[Tensor, "1 1 256 256"] | None = None,
|
||||
binarize: bool = True,
|
||||
) -> tuple[Tensor, Tensor, Tensor]:
|
||||
if isinstance(input, ImageEmbedding):
|
||||
|
@ -74,15 +75,13 @@ class SegmentAnything(fl.Module):
|
|||
)
|
||||
self.point_encoder.set_type_mask(type_mask=type_mask)
|
||||
|
||||
if masks is not None:
|
||||
mask_tensor = torch.stack(
|
||||
tensors=[image_to_tensor(image=mask, device=self.device, dtype=self.dtype) for mask in masks]
|
||||
)
|
||||
mask_embedding = self.mask_encoder(mask_tensor)
|
||||
if low_res_mask is not None:
|
||||
mask_embedding = self.mask_encoder(low_res_mask)
|
||||
else:
|
||||
mask_embedding = self.mask_encoder.get_no_mask_dense_embedding(
|
||||
image_embedding_size=self.image_encoder.image_embedding_size
|
||||
)
|
||||
|
||||
point_embedding = self.point_encoder(
|
||||
self.normalize(coordinates, target_size=target_size, original_size=original_size)
|
||||
)
|
||||
|
|
|
@ -23,7 +23,7 @@ from refiners.foundationals.segment_anything.image_encoder import FusedSelfAtten
|
|||
from refiners.foundationals.segment_anything.model import SegmentAnythingH
|
||||
from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer
|
||||
|
||||
# See predictor_example.ipynb official notebook (note: mask_input is not yet properly supported)
|
||||
# See predictor_example.ipynb official notebook
|
||||
PROMPTS: list[SAMPrompt] = [
|
||||
SAMPrompt(foreground_points=((500, 375),)),
|
||||
SAMPrompt(background_points=((500, 375),)),
|
||||
|
@ -41,7 +41,9 @@ def prompt(request: pytest.FixtureRequest) -> SAMPrompt:
|
|||
|
||||
@pytest.fixture
|
||||
def one_prompt() -> SAMPrompt:
|
||||
return PROMPTS[0]
|
||||
# Using the third prompt of the PROMPTS list in order to strictly do the same test as the official notebook in the
|
||||
# test_predictor_dense_mask test.
|
||||
return PROMPTS[2]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
@ -83,8 +85,7 @@ def facebook_sam_h_predictor(facebook_sam_h: FacebookSAM) -> FacebookSAMPredicto
|
|||
@pytest.fixture(scope="module")
|
||||
def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
|
||||
sam_h = SegmentAnythingH(device=test_device)
|
||||
# TODO: make strict=True when the MasKEncoder conversion is done
|
||||
sam_h.load_from_safetensors(tensors_path=sam_h_weights, strict=False)
|
||||
sam_h.load_from_safetensors(tensors_path=sam_h_weights)
|
||||
return sam_h
|
||||
|
||||
|
||||
|
@ -164,7 +165,14 @@ def test_point_encoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH, pro
|
|||
**prompt.facebook_prompt_encoder_kwargs(device=facebook_sam_h.device)
|
||||
)
|
||||
|
||||
coordinates, type_mask = refiners_prompt_encoder.points_to_tensor(**prompt.__dict__)
|
||||
prompt_dict = prompt.__dict__
|
||||
# Skip mask prompt, if any, since the point encoder only consumes points and boxes
|
||||
# TODO: split `SAMPrompt` and introduce a dedicated one for dense prompts
|
||||
prompt_dict.pop("low_res_mask", None)
|
||||
|
||||
assert prompt_dict is not None, "`test_point_encoder` cannot be called with just a `low_res_mask`"
|
||||
|
||||
coordinates, type_mask = refiners_prompt_encoder.points_to_tensor(**prompt_dict)
|
||||
# Shift to center of pixel + normalize in [0, 1] (see `_embed_points` in segment-anything official repo)
|
||||
coordinates[:, :, 0] = (coordinates[:, :, 0] + 0.5) / 1024.0
|
||||
coordinates[:, :, 1] = (coordinates[:, :, 1] + 0.5) / 1024.0
|
||||
|
@ -319,3 +327,91 @@ def test_predictor_image_embedding(sam_h: SegmentAnythingH, truck: Image.Image,
|
|||
|
||||
assert torch.equal(masks, masks_ref)
|
||||
assert torch.equal(scores_ref, scores)
|
||||
|
||||
|
||||
def test_predictor_dense_mask(
|
||||
facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt
|
||||
) -> None:
|
||||
"""
|
||||
NOTE : Binarizing intermediate masks isn't necessary, as per SamPredictor.predict_torch docstring:
|
||||
> mask_input (np.ndarray): A low resolution mask input to the model, typically
|
||||
> coming from a previous prediction iteration. Has form Bx1xHxW, where
|
||||
> for SAM, H=W=256. Masks returned by a previous iteration of the
|
||||
> predict method do not need further transformation.
|
||||
"""
|
||||
predictor = facebook_sam_h_predictor
|
||||
predictor.set_image(np.array(truck))
|
||||
facebook_masks, facebook_scores, facebook_logits = predictor.predict(
|
||||
**one_prompt.facebook_predict_kwargs(), # type: ignore
|
||||
multimask_output=True,
|
||||
)
|
||||
|
||||
assert len(facebook_masks) == 3
|
||||
|
||||
facebook_mask_input = facebook_logits[np.argmax(facebook_scores)] # shape: HxW
|
||||
|
||||
# Using the same mask coordinates inputs as the official notebook
|
||||
facebook_prompt = SAMPrompt(
|
||||
foreground_points=((500, 375),), background_points=((1125, 625),), low_res_mask=facebook_mask_input[None, ...]
|
||||
)
|
||||
facebook_dense_masks, _, _ = predictor.predict(**facebook_prompt.facebook_predict_kwargs(), multimask_output=True) # type: ignore
|
||||
|
||||
assert len(facebook_dense_masks) == 3
|
||||
|
||||
masks, scores, logits = sam_h.predict(truck, **one_prompt.__dict__)
|
||||
masks = masks.squeeze(0)
|
||||
scores = scores.squeeze(0)
|
||||
|
||||
assert len(masks) == 3
|
||||
|
||||
mask_input = logits[:, scores.max(dim=0).indices, ...] # shape: 1xHxW
|
||||
|
||||
assert np.allclose(
|
||||
mask_input.cpu().numpy(), facebook_mask_input, atol=1e-1
|
||||
) # Lower doesn't pass, but it's close enough for logits
|
||||
|
||||
refiners_prompt = SAMPrompt(
|
||||
foreground_points=((500, 375),), background_points=((1125, 625),), low_res_mask=mask_input.unsqueeze(0)
|
||||
)
|
||||
dense_masks, _, _ = sam_h.predict(truck, **refiners_prompt.__dict__)
|
||||
dense_masks = dense_masks.squeeze(0)
|
||||
|
||||
assert len(dense_masks) == 3
|
||||
|
||||
for i in range(3):
|
||||
dense_mask_prediction = dense_masks[i].cpu()
|
||||
facebook_dense_mask = torch.as_tensor(facebook_dense_masks[i])
|
||||
assert dense_mask_prediction.shape == facebook_dense_mask.shape
|
||||
assert isclose(intersection_over_union(dense_mask_prediction, facebook_dense_mask), 1.0, rel_tol=5e-05)
|
||||
|
||||
|
||||
def test_mask_encoder(
|
||||
facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt
|
||||
) -> None:
|
||||
predictor = facebook_sam_h_predictor
|
||||
predictor.set_image(np.array(truck))
|
||||
_, facebook_scores, facebook_logits = predictor.predict(
|
||||
**one_prompt.facebook_predict_kwargs(), # type: ignore
|
||||
multimask_output=True,
|
||||
)
|
||||
facebook_mask_input = facebook_logits[np.argmax(facebook_scores)]
|
||||
facebook_mask_input = (
|
||||
torch.from_numpy(facebook_mask_input) # type: ignore
|
||||
.to(device=predictor.model.device)
|
||||
.unsqueeze(0)
|
||||
.unsqueeze(0) # shape: 1x1xHxW
|
||||
)
|
||||
|
||||
_, fb_dense_embeddings = predictor.model.prompt_encoder(
|
||||
points=None,
|
||||
boxes=None,
|
||||
masks=facebook_mask_input,
|
||||
)
|
||||
|
||||
_, scores, logits = sam_h.predict(truck, **one_prompt.__dict__)
|
||||
scores = scores.squeeze(0)
|
||||
mask_input = logits[:, scores.max(dim=0).indices, ...].unsqueeze(0) # shape: 1x1xHxW
|
||||
dense_embeddings = sam_h.mask_encoder(mask_input)
|
||||
|
||||
assert facebook_mask_input.shape == mask_input.shape
|
||||
assert torch.allclose(dense_embeddings, fb_dense_embeddings, atol=1e-4, rtol=1e-4)
|
||||
|
|
|
@ -63,8 +63,7 @@ class SAMPrompt:
|
|||
foreground_points: Sequence[tuple[float, float]] | None = None
|
||||
background_points: Sequence[tuple[float, float]] | None = None
|
||||
box_points: Sequence[Sequence[tuple[float, float]]] | None = None
|
||||
# TODO: support masks
|
||||
# masks: Sequence[Image.Image] | None = None
|
||||
low_res_mask: Tensor | None = None
|
||||
|
||||
def facebook_predict_kwargs(self) -> dict[str, NDArray]:
|
||||
prompt: dict[str, NDArray] = {}
|
||||
|
@ -85,13 +84,18 @@ class SAMPrompt:
|
|||
prompt["box"] = np.array([coord for batch in self.box_points for xy in batch for coord in xy]).reshape(
|
||||
len(self.box_points), 4
|
||||
)
|
||||
if self.low_res_mask is not None:
|
||||
prompt["mask_input"] = np.array(self.low_res_mask)
|
||||
return prompt
|
||||
|
||||
def facebook_prompt_encoder_kwargs(self, device: torch.device | None = None):
|
||||
def facebook_prompt_encoder_kwargs(
|
||||
self, device: torch.device | None = None
|
||||
) -> dict[str, Tensor | tuple[Tensor, Tensor | None] | None]:
|
||||
prompt = self.facebook_predict_kwargs()
|
||||
coords: Tensor | None = None
|
||||
labels: Tensor | None = None
|
||||
boxes: Tensor | None = None
|
||||
masks: Tensor | None = None
|
||||
if "point_coords" in prompt:
|
||||
coords = torch.as_tensor(prompt["point_coords"], dtype=torch.float, device=device).unsqueeze(0)
|
||||
if "point_labels" in prompt:
|
||||
|
@ -99,8 +103,9 @@ class SAMPrompt:
|
|||
if "box" in prompt:
|
||||
boxes = torch.as_tensor(prompt["box"], dtype=torch.float, device=device).unsqueeze(0)
|
||||
points = (coords, labels) if coords is not None else None
|
||||
# TODO: support masks
|
||||
return {"points": points, "boxes": boxes, "masks": None}
|
||||
if "mask_input" in prompt:
|
||||
masks = torch.as_tensor(prompt["mask_input"], dtype=torch.float, device=device).unsqueeze(0)
|
||||
return {"points": points, "boxes": boxes, "masks": masks}
|
||||
|
||||
|
||||
def intersection_over_union(
|
||||
|
|
Loading…
Reference in a new issue