mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 23:28:45 +00:00
SegmentAnything: add dense mask prompt support
This commit is contained in:
parent 20c229903f
commit 00f494efe2
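This commit replaces the unfinished `masks: Sequence[Image.Image]` prompt with a `low_res_mask` dense prompt routed through the newly converted `MaskEncoder`. A minimal sketch of the iterative-refinement flow this enables, following the call pattern of the new tests below (the device, weights path, image file, and point coordinates are placeholders, not part of the commit):

import torch
from PIL import Image

from refiners.foundationals.segment_anything.model import SegmentAnythingH

sam_h = SegmentAnythingH(device=torch.device("cuda"))  # placeholder device
sam_h.load_from_safetensors(tensors_path="sam_h.safetensors")  # placeholder path

image = Image.open("truck.jpg")  # placeholder image

# First pass: sparse (point) prompts only.
masks, scores, logits = sam_h.predict(image, foreground_points=((500, 375),))

# Second pass: feed the best low-resolution logits back as a dense prompt.
best = scores.squeeze(0).max(dim=0).indices
low_res_mask = logits[:, best, ...].unsqueeze(0)  # shape: 1x1x256x256
masks, scores, logits = sam_h.predict(
    image,
    foreground_points=((500, 375),),
    background_points=((1125, 625),),
    low_res_mask=low_res_mask,
)

Passing the previous iteration's raw logits, rather than a binarized mask, matches the upstream SamPredictor guidance quoted in test_predictor_dense_mask below.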
@@ -37,13 +37,36 @@ class Args(argparse.Namespace):


 def convert_mask_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]:
+    manual_seed(seed=0)
+    refiners_mask_encoder = MaskEncoder()
+
+    converter = ModelConverter(
+        source_model=prompt_encoder.mask_downscaling,
+        target_model=refiners_mask_encoder,
+        custom_layer_mapping=custom_layers,  # type: ignore
+    )
+
+    x = torch.randn(1, 256, 256)
+    mapping = converter.map_state_dicts(source_args=(x,))
+    assert mapping
+
+    source_state_dict = prompt_encoder.mask_downscaling.state_dict()
+    target_state_dict = refiners_mask_encoder.state_dict()
+
+    # Mapping handled manually (see below) because nn.Parameter is a special case
+    del target_state_dict["no_mask_embedding"]
+
+    converted_source = converter._convert_state_dict(  # pyright: ignore[reportPrivateUsage]
+        source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
+    )
+
     state_dict: dict[str, Tensor] = {
         "no_mask_embedding": nn.Parameter(data=prompt_encoder.no_mask_embed.weight.clone()),  # type: ignore
     }
+    state_dict.update(converted_source)

-    refiners_mask_encoder = MaskEncoder()
-    # TODO: handle other weights
-    refiners_mask_encoder.load_state_dict(state_dict=state_dict, strict=False)
+    refiners_mask_encoder.load_state_dict(state_dict=state_dict)

     return state_dict

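For context, the source module being mapped here, `prompt_encoder.mask_downscaling` from the official segment-anything repo, is a small strided conv stack that turns a 1x256x256 low-resolution mask into a 256x64x64 dense embedding, which is why `map_state_dicts` traces with a `torch.randn(1, 256, 256)` dummy input. A rough reconstruction from memory of the upstream code (not part of this commit):

import torch
from torch import nn


class LayerNorm2d(nn.Module):
    # Channel-wise layer norm over NCHW tensors, as defined in segment-anything.
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        return self.weight[:, None, None] * x + self.bias[:, None, None]


mask_downscaling = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=2, stride=2),  # 256x256 -> 128x128
    LayerNorm2d(4),
    nn.GELU(),
    nn.Conv2d(4, 16, kernel_size=2, stride=2),  # 128x128 -> 64x64
    LayerNorm2d(16),
    nn.GELU(),
    nn.Conv2d(16, 256, kernel_size=1),  # project to the 256-dim embedding
)

The two plain `nn.Conv2d` weights convert mechanically via ModelConverter; only the standalone `no_mask_embedding` nn.Parameter needs the manual handling shown above.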
@@ -3,11 +3,12 @@ from typing import Sequence

 import numpy as np
 import torch
+from jaxtyping import Float
 from PIL import Image
 from torch import Tensor, device as Device, dtype as DType

 import refiners.fluxion.layers as fl
-from refiners.fluxion.utils import image_to_tensor, interpolate, no_grad, normalize, pad
+from refiners.fluxion.utils import interpolate, no_grad, normalize, pad
 from refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH
 from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
 from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder

@@ -55,7 +56,7 @@ class SegmentAnything(fl.Module):
         foreground_points: Sequence[tuple[float, float]] | None = None,
         background_points: Sequence[tuple[float, float]] | None = None,
         box_points: Sequence[Sequence[tuple[float, float]]] | None = None,
-        masks: Sequence[Image.Image] | None = None,
+        low_res_mask: Float[Tensor, "1 1 256 256"] | None = None,
         binarize: bool = True,
     ) -> tuple[Tensor, Tensor, Tensor]:
         if isinstance(input, ImageEmbedding):

@@ -74,15 +75,13 @@ class SegmentAnything(fl.Module):
         )
         self.point_encoder.set_type_mask(type_mask=type_mask)

-        if masks is not None:
-            mask_tensor = torch.stack(
-                tensors=[image_to_tensor(image=mask, device=self.device, dtype=self.dtype) for mask in masks]
-            )
-            mask_embedding = self.mask_encoder(mask_tensor)
+        if low_res_mask is not None:
+            mask_embedding = self.mask_encoder(low_res_mask)
         else:
             mask_embedding = self.mask_encoder.get_no_mask_dense_embedding(
                 image_embedding_size=self.image_encoder.image_embedding_size
             )

         point_embedding = self.point_encoder(
             self.normalize(coordinates, target_size=target_size, original_size=original_size)
         )

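When no dense prompt is given, the fallback `get_no_mask_dense_embedding` produces a dense embedding from the learned `no_mask_embedding` parameter converted earlier. A hedged sketch of the expected shape semantics (the helper's internals are not shown in this diff); for SAM-H the image embedding grid is 64x64:

import torch

embed_dim, h, w = 256, 64, 64  # SAM-H image embedding size (1024 / patch size 16)
no_mask_embedding = torch.zeros(1, embed_dim)  # stand-in for the learned parameter
dense_embedding = no_mask_embedding.reshape(1, embed_dim, 1, 1).expand(1, embed_dim, h, w)
assert dense_embedding.shape == (1, 256, 64, 64)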
@@ -23,7 +23,7 @@ from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention
 from refiners.foundationals.segment_anything.model import SegmentAnythingH
 from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer

-# See predictor_example.ipynb official notebook (note: mask_input is not yet properly supported)
+# See predictor_example.ipynb official notebook
 PROMPTS: list[SAMPrompt] = [
     SAMPrompt(foreground_points=((500, 375),)),
     SAMPrompt(background_points=((500, 375),)),

@@ -41,7 +41,9 @@ def prompt(request: pytest.FixtureRequest) -> SAMPrompt:

 @pytest.fixture
 def one_prompt() -> SAMPrompt:
-    return PROMPTS[0]
+    # Using the third prompt of the PROMPTS list in order to strictly do the same test as the official notebook in the
+    # test_predictor_dense_mask test.
+    return PROMPTS[2]


 @pytest.fixture(scope="module")

@@ -83,8 +85,7 @@ def facebook_sam_h_predictor(facebook_sam_h: FacebookSAM) -> FacebookSAMPredictor:
 @pytest.fixture(scope="module")
 def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
     sam_h = SegmentAnythingH(device=test_device)
-    # TODO: make strict=True when the MasKEncoder conversion is done
-    sam_h.load_from_safetensors(tensors_path=sam_h_weights, strict=False)
+    sam_h.load_from_safetensors(tensors_path=sam_h_weights)
     return sam_h

@@ -164,7 +165,14 @@ def test_point_encoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH, prompt: SAMPrompt) -> None:
         **prompt.facebook_prompt_encoder_kwargs(device=facebook_sam_h.device)
     )

-    coordinates, type_mask = refiners_prompt_encoder.points_to_tensor(**prompt.__dict__)
+    prompt_dict = prompt.__dict__
+    # Skip mask prompt, if any, since the point encoder only consumes points and boxes
+    # TODO: split `SAMPrompt` and introduce a dedicated one for dense prompts
+    prompt_dict.pop("low_res_mask", None)
+
+    assert prompt_dict is not None, "`test_point_encoder` cannot be called with just a `low_res_mask`"
+
+    coordinates, type_mask = refiners_prompt_encoder.points_to_tensor(**prompt_dict)
     # Shift to center of pixel + normalize in [0, 1] (see `_embed_points` in segment-anything official repo)
     coordinates[:, :, 0] = (coordinates[:, :, 0] + 0.5) / 1024.0
     coordinates[:, :, 1] = (coordinates[:, :, 1] + 0.5) / 1024.0

@@ -319,3 +327,91 @@ def test_predictor_image_embedding(sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt) -> None:

     assert torch.equal(masks, masks_ref)
     assert torch.equal(scores_ref, scores)
+
+
+def test_predictor_dense_mask(
+    facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt
+) -> None:
+    """
+    NOTE : Binarizing intermediate masks isn't necessary, as per SamPredictor.predict_torch docstring:
+    > mask_input (np.ndarray): A low resolution mask input to the model, typically
+    > coming from a previous prediction iteration. Has form Bx1xHxW, where
+    > for SAM, H=W=256. Masks returned by a previous iteration of the
+    > predict method do not need further transformation.
+    """
+    predictor = facebook_sam_h_predictor
+    predictor.set_image(np.array(truck))
+    facebook_masks, facebook_scores, facebook_logits = predictor.predict(
+        **one_prompt.facebook_predict_kwargs(),  # type: ignore
+        multimask_output=True,
+    )
+
+    assert len(facebook_masks) == 3
+
+    facebook_mask_input = facebook_logits[np.argmax(facebook_scores)]  # shape: HxW
+
+    # Using the same mask coordinates inputs as the official notebook
+    facebook_prompt = SAMPrompt(
+        foreground_points=((500, 375),), background_points=((1125, 625),), low_res_mask=facebook_mask_input[None, ...]
+    )
+    facebook_dense_masks, _, _ = predictor.predict(**facebook_prompt.facebook_predict_kwargs(), multimask_output=True)  # type: ignore
+
+    assert len(facebook_dense_masks) == 3
+
+    masks, scores, logits = sam_h.predict(truck, **one_prompt.__dict__)
+    masks = masks.squeeze(0)
+    scores = scores.squeeze(0)
+
+    assert len(masks) == 3
+
+    mask_input = logits[:, scores.max(dim=0).indices, ...]  # shape: 1xHxW
+
+    assert np.allclose(
+        mask_input.cpu().numpy(), facebook_mask_input, atol=1e-1
+    )  # Lower doesn't pass, but it's close enough for logits
+
+    refiners_prompt = SAMPrompt(
+        foreground_points=((500, 375),), background_points=((1125, 625),), low_res_mask=mask_input.unsqueeze(0)
+    )
+    dense_masks, _, _ = sam_h.predict(truck, **refiners_prompt.__dict__)
+    dense_masks = dense_masks.squeeze(0)
+
+    assert len(dense_masks) == 3
+
+    for i in range(3):
+        dense_mask_prediction = dense_masks[i].cpu()
+        facebook_dense_mask = torch.as_tensor(facebook_dense_masks[i])
+        assert dense_mask_prediction.shape == facebook_dense_mask.shape
+        assert isclose(intersection_over_union(dense_mask_prediction, facebook_dense_mask), 1.0, rel_tol=5e-05)
+
+
+def test_mask_encoder(
+    facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt
+) -> None:
+    predictor = facebook_sam_h_predictor
+    predictor.set_image(np.array(truck))
+    _, facebook_scores, facebook_logits = predictor.predict(
+        **one_prompt.facebook_predict_kwargs(),  # type: ignore
+        multimask_output=True,
+    )
+    facebook_mask_input = facebook_logits[np.argmax(facebook_scores)]
+    facebook_mask_input = (
+        torch.from_numpy(facebook_mask_input)  # type: ignore
+        .to(device=predictor.model.device)
+        .unsqueeze(0)
+        .unsqueeze(0)  # shape: 1x1xHxW
+    )
+
+    _, fb_dense_embeddings = predictor.model.prompt_encoder(
+        points=None,
+        boxes=None,
+        masks=facebook_mask_input,
+    )
+
+    _, scores, logits = sam_h.predict(truck, **one_prompt.__dict__)
+    scores = scores.squeeze(0)
+    mask_input = logits[:, scores.max(dim=0).indices, ...].unsqueeze(0)  # shape: 1x1xHxW
+    dense_embeddings = sam_h.mask_encoder(mask_input)
+
+    assert facebook_mask_input.shape == mask_input.shape
+    assert torch.allclose(dense_embeddings, fb_dense_embeddings, atol=1e-4, rtol=1e-4)

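The `intersection_over_union` helper asserted against above is only referenced in this diff (its definition opens in the last hunk below but its body is cut off). A plausible boolean-mask implementation, offered as an assumption rather than the repo's actual code:

import torch
from torch import Tensor


def intersection_over_union(input_mask: Tensor, other_mask: Tensor, epsilon: float = 1e-6) -> float:
    # Treat nonzero entries as foreground; epsilon guards against an empty union.
    input_mask, other_mask = input_mask.bool(), other_mask.bool()
    intersection = (input_mask & other_mask).float().sum()
    union = (input_mask | other_mask).float().sum()
    return float(intersection / (union + epsilon))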
@@ -63,8 +63,7 @@ class SAMPrompt:
     foreground_points: Sequence[tuple[float, float]] | None = None
     background_points: Sequence[tuple[float, float]] | None = None
     box_points: Sequence[Sequence[tuple[float, float]]] | None = None
-    # TODO: support masks
-    # masks: Sequence[Image.Image] | None = None
+    low_res_mask: Tensor | None = None

     def facebook_predict_kwargs(self) -> dict[str, NDArray]:
         prompt: dict[str, NDArray] = {}

@@ -85,13 +84,18 @@ class SAMPrompt:
             prompt["box"] = np.array([coord for batch in self.box_points for xy in batch for coord in xy]).reshape(
                 len(self.box_points), 4
             )
+        if self.low_res_mask is not None:
+            prompt["mask_input"] = np.array(self.low_res_mask)
         return prompt

-    def facebook_prompt_encoder_kwargs(self, device: torch.device | None = None):
+    def facebook_prompt_encoder_kwargs(
+        self, device: torch.device | None = None
+    ) -> dict[str, Tensor | tuple[Tensor, Tensor | None] | None]:
         prompt = self.facebook_predict_kwargs()
         coords: Tensor | None = None
         labels: Tensor | None = None
         boxes: Tensor | None = None
+        masks: Tensor | None = None
         if "point_coords" in prompt:
             coords = torch.as_tensor(prompt["point_coords"], dtype=torch.float, device=device).unsqueeze(0)
         if "point_labels" in prompt:

|
||||||
if "box" in prompt:
|
if "box" in prompt:
|
||||||
boxes = torch.as_tensor(prompt["box"], dtype=torch.float, device=device).unsqueeze(0)
|
boxes = torch.as_tensor(prompt["box"], dtype=torch.float, device=device).unsqueeze(0)
|
||||||
points = (coords, labels) if coords is not None else None
|
points = (coords, labels) if coords is not None else None
|
||||||
# TODO: support masks
|
if "mask_input" in prompt:
|
||||||
return {"points": points, "boxes": boxes, "masks": None}
|
masks = torch.as_tensor(prompt["mask_input"], dtype=torch.float, device=device).unsqueeze(0)
|
||||||
|
return {"points": points, "boxes": boxes, "masks": masks}
|
||||||
|
|
||||||
|
|
||||||
def intersection_over_union(
|
def intersection_over_union(
|
||||||
|
|
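To illustrate the two helpers above (not part of the commit): a dense prompt round-trips through both kwargs builders with the shape each downstream consumer expects.

import torch

# Hypothetical import path for the tests' utils module:
# from tests.foundationals.segment_anything.utils import SAMPrompt

prompt = SAMPrompt(foreground_points=((500, 375),), low_res_mask=torch.zeros(1, 256, 256))

fb_kwargs = prompt.facebook_predict_kwargs()
# fb_kwargs["mask_input"] is a (1, 256, 256) array, as SamPredictor.predict expects

enc_kwargs = prompt.facebook_prompt_encoder_kwargs()
# enc_kwargs["masks"] is a (1, 1, 256, 256) tensor, after the unsqueeze above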