from math import isclose
from pathlib import Path
from typing import cast

import numpy as np
import pytest
import torch
import torch.nn as nn
from PIL import Image
from tests.foundationals.segment_anything.utils import (
    FacebookSAM,
    FacebookSAMPredictor,
    SAMPrompt,
    intersection_over_union,
)
from torch import Tensor

import refiners.fluxion.layers as fl
from refiners.conversion.model_converter import ModelConverter
from refiners.fluxion import manual_seed
from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention, RelativePositionAttention
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
from refiners.foundationals.segment_anything.model import ImageEmbedding, SegmentAnythingH
from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer

# See predictor_example.ipynb official notebook
PROMPTS: list[SAMPrompt] = [
    SAMPrompt(foreground_points=((500, 375),)),
    SAMPrompt(background_points=((500, 375),)),
    SAMPrompt(foreground_points=((500, 375), (1125, 625))),
    SAMPrompt(foreground_points=((500, 375),), background_points=((1125, 625),)),
    SAMPrompt(box_points=[[(425, 600), (700, 875)]]),
    SAMPrompt(box_points=[[(425, 600), (700, 875)]], background_points=((575, 750),)),
]


@pytest.fixture(params=PROMPTS)
def prompt(request: pytest.FixtureRequest) -> SAMPrompt:
    return request.param


@pytest.fixture
def one_prompt() -> SAMPrompt:
    # Using the third prompt of the PROMPTS list in order to strictly do the same test as the official notebook in the
    # test_predictor_dense_mask test.
    return PROMPTS[2]


@pytest.fixture(scope="module")
def facebook_sam_h(sam_h_unconverted_weights_path: Path, test_device: torch.device) -> FacebookSAM:
    from segment_anything import build_sam_vit_h  # type: ignore

    sam_h = cast(FacebookSAM, build_sam_vit_h())
    sam_h.load_state_dict(state_dict=load_tensors(sam_h_unconverted_weights_path))
    return sam_h.to(device=test_device)


@pytest.fixture(scope="module")
def facebook_sam_h_predictor(facebook_sam_h: FacebookSAM) -> FacebookSAMPredictor:
    from segment_anything import SamPredictor  # type: ignore
    from segment_anything.modeling import Sam  # type: ignore

    predictor = SamPredictor(cast(Sam, facebook_sam_h))  # type: ignore
    return cast(FacebookSAMPredictor, predictor)


@pytest.fixture(scope="module")
def sam_h(sam_h_weights_path: Path, test_device: torch.device) -> SegmentAnythingH:
    sam_h = SegmentAnythingH(device=test_device)
    sam_h.load_from_safetensors(tensors_path=sam_h_weights_path)
    return sam_h


@pytest.fixture(scope="module")
def sam_h_single_output(sam_h_weights_path: Path, test_device: torch.device) -> SegmentAnythingH:
    sam_h = SegmentAnythingH(multimask_output=False, device=test_device)
    sam_h.load_from_safetensors(tensors_path=sam_h_weights_path)
    return sam_h


@pytest.fixture(scope="module")
def truck(ref_path: Path) -> Image.Image:
    return Image.open(ref_path / "truck.jpg").convert("RGB")  # type: ignore


@no_grad()
def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None:
    manual_seed(seed=0)
    x = torch.randn(25, 14, 14, 1280, device=facebook_sam_h.device)

    attention = cast(nn.Module, facebook_sam_h.image_encoder.blocks[0].attn)

    refiners_attention = FusedSelfAttention(
        embedding_dim=1280, num_heads=16, spatial_size=(14, 14), device=facebook_sam_h.device
    )
    rpa = refiners_attention.layer("RelativePositionAttention", RelativePositionAttention)
    linear_1 = refiners_attention.layer("Linear_1", fl.Linear)
    linear_2 = refiners_attention.layer("Linear_2", fl.Linear)
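
    # Copy the Facebook attention parameters onto the refiners layers:
    # qkv -> Linear_1, proj -> Linear_2, rel_pos_{w,h} -> the relative position embeddings.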
    linear_1.weight = attention.qkv.weight
    linear_1.bias = attention.qkv.bias
    linear_2.weight = attention.proj.weight
    linear_2.bias = attention.proj.bias
    rpa.horizontal_embedding = attention.rel_pos_w
    rpa.vertical_embedding = attention.rel_pos_h

    y_1 = attention(x)
    assert y_1.shape == x.shape

    y_2 = refiners_attention(x)
    assert y_2.shape == x.shape

    assert torch.equal(input=y_1, other=y_2)


def test_mask_decoder_arg() -> None:
    mask_decoder_default = MaskDecoder()
    sam_h = SegmentAnythingH(mask_decoder=mask_decoder_default)
    assert sam_h.mask_decoder == mask_decoder_default


def test_multimask_output_error() -> None:
    mask_decoder_multimask_output = MaskDecoder(multimask_output=True)
    with pytest.raises(AssertionError, match="multimask_output"):
        SegmentAnythingH(mask_decoder=mask_decoder_multimask_output, multimask_output=False)


@no_grad()
def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, truck: Image.Image) -> None:
    resized = truck.resize(size=(1024, 1024))  # type: ignore
    image_tensor = image_to_tensor(image=resized, device=facebook_sam_h.device)

    y_1 = facebook_sam_h.image_encoder(image_tensor)
    y_2 = sam_h.image_encoder(image_tensor)

    assert torch.allclose(input=y_1, other=y_2, atol=1e-4)


@no_grad()
def test_prompt_encoder_dense_positional_embedding(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
    facebook_prompt_encoder = facebook_sam_h.prompt_encoder
    refiners_prompt_encoder = sam_h.point_encoder

    facebook_dense_pe: Tensor = cast(Tensor, facebook_prompt_encoder.get_dense_pe())  # type: ignore
    refiners_dense_pe = refiners_prompt_encoder.get_dense_positional_embedding(image_embedding_size=(64, 64))

    assert torch.equal(input=refiners_dense_pe, other=facebook_dense_pe)


@no_grad()
def test_prompt_encoder_no_mask_dense_embedding(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
    facebook_prompt_encoder = facebook_sam_h.prompt_encoder
    refiners_prompt_encoder = sam_h.mask_encoder

    _, facebook_dense_pe = facebook_prompt_encoder(points=None, boxes=None, masks=None)
    refiners_dense_pe = refiners_prompt_encoder.get_no_mask_dense_embedding(image_embedding_size=(64, 64))

    assert torch.equal(input=refiners_dense_pe, other=facebook_dense_pe)


@no_grad()
def test_point_encoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH, prompt: SAMPrompt) -> None:
    facebook_prompt_encoder = facebook_sam_h.prompt_encoder
    refiners_prompt_encoder = sam_h.point_encoder

    facebook_sparse_pe, _ = facebook_prompt_encoder(
        **prompt.facebook_prompt_encoder_kwargs(device=facebook_sam_h.device)
    )

    prompt_dict = prompt.__dict__
    # Skip the mask prompt, if any, since the point encoder only consumes points and boxes
    # TODO: split `SAMPrompt` and introduce a dedicated one for dense prompts
    prompt_dict.pop("low_res_mask", None)
    assert prompt_dict is not None, "`test_point_encoder` cannot be called with just a `low_res_mask`"

    coordinates, type_mask = refiners_prompt_encoder.points_to_tensor(**prompt_dict)
    # Shift to the center of the pixel + normalize into [0, 1] (see `_embed_points` in the official segment-anything repo)
    coordinates[:, :, 0] = (coordinates[:, :, 0] + 0.5) / 1024.0
    coordinates[:, :, 1] = (coordinates[:, :, 1] + 0.5) / 1024.0
    refiners_prompt_encoder.set_type_mask(type_mask=type_mask)
    refiners_sparse_pe = refiners_prompt_encoder(coordinates)

    assert torch.equal(input=refiners_sparse_pe, other=facebook_sparse_pe)


@no_grad()
def test_two_way_transformer(facebook_sam_h: FacebookSAM) -> None:
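    # Random inputs shaped like the mask decoder context: 64*64 dense (image) tokens and 3 sparse (prompt) tokens, each 256-dim.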
    dense_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device)
    dense_positional_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device)
    sparse_embedding = torch.randn(1, 3, 256, device=facebook_sam_h.device)

    refiners_layer = TwoWayTransformerLayer(
        embedding_dim=256, feed_forward_dim=2048, num_heads=8, device=facebook_sam_h.device
    )
    facebook_layer = facebook_sam_h.mask_decoder.transformer.layers[1]  # type: ignore
    assert isinstance(facebook_layer, nn.Module)

    refiners_layer.set_context(
        context="mask_decoder",
        value={
            "dense_embedding": dense_embedding,
            "dense_positional_embedding": dense_positional_embedding,
            "sparse_embedding": sparse_embedding,
        },
    )
    facebook_inputs = {
        "queries": sparse_embedding,
        "keys": dense_embedding,
        "query_pe": sparse_embedding,
        "key_pe": dense_positional_embedding,
    }

    converter = ModelConverter(
        source_model=facebook_layer,
        target_model=refiners_layer,
        skip_output_check=True,  # done below, manually
    )
    assert converter.run(source_args=facebook_inputs, target_args=(sparse_embedding,))

    refiners_layer.set_context(
        context="mask_decoder",
        value={
            "dense_embedding": dense_embedding,
            "dense_positional_embedding": dense_positional_embedding,
            "sparse_embedding": sparse_embedding,
        },
    )

    y_1 = facebook_layer(**facebook_inputs)[0]
    y_2 = refiners_layer(sparse_embedding)[0]

    assert torch.equal(input=y_1, other=y_2)


@no_grad()
def test_mask_decoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
    manual_seed(seed=0)
    facebook_mask_decoder = facebook_sam_h.mask_decoder
    refiners_mask_decoder = sam_h.mask_decoder

    image_embedding = torch.randn(1, 256, 64, 64, device=facebook_sam_h.device)
    dense_positional_embedding = torch.randn(1, 256, 64, 64, device=facebook_sam_h.device)
    point_embedding = torch.randn(1, 3, 256, device=facebook_sam_h.device)
    mask_embedding = torch.randn(1, 256, 64, 64, device=facebook_sam_h.device)

    from segment_anything.modeling.common import LayerNorm2d  # type: ignore

    assert issubclass(LayerNorm2d, nn.Module)
    custom_layers = {LayerNorm2d: fl.LayerNorm2d}

    converter = ModelConverter(
        source_model=facebook_mask_decoder,
        target_model=refiners_mask_decoder,
        custom_layer_mapping=custom_layers,  # type: ignore
    )

    inputs = {
        "image_embeddings": image_embedding,
        "image_pe": dense_positional_embedding,
        "sparse_prompt_embeddings": point_embedding,
        "dense_prompt_embeddings": mask_embedding,
        "multimask_output": True,
    }

    refiners_mask_decoder.set_image_embedding(image_embedding)
    refiners_mask_decoder.set_point_embedding(point_embedding)
    refiners_mask_decoder.set_mask_embedding(mask_embedding)
    refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding)

    mapping = converter.map_state_dicts(source_args=inputs, target_args={})
    assert mapping is not None
    mapping["MaskDecoderTokens.Parameter"] = "iou_token"

    state_dict = converter._convert_state_dict(  # type: ignore
        source_state_dict=facebook_mask_decoder.state_dict(),
        target_state_dict=refiners_mask_decoder.state_dict(),
        state_dict_mapping=mapping,
    )
    state_dict["MaskDecoderTokens.Parameter.weight"] = torch.cat(
        [facebook_mask_decoder.iou_token.weight, facebook_mask_decoder.mask_tokens.weight], dim=0
    )  # type: ignore
    refiners_mask_decoder.load_state_dict(state_dict=state_dict)

    facebook_output = facebook_mask_decoder(**inputs)

    refiners_mask_decoder.set_image_embedding(image_embedding)
    refiners_mask_decoder.set_point_embedding(point_embedding)
    refiners_mask_decoder.set_mask_embedding(mask_embedding)
    refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding)
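
    # Run the converted refiners decoder and compare both of its outputs against the Facebook decoder.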
    mask_prediction, iou_prediction = refiners_mask_decoder()

    facebook_masks = facebook_output[0]
    facebook_prediction = facebook_output[1]

    assert torch.equal(input=mask_prediction, other=facebook_masks)
    assert torch.equal(input=iou_prediction, other=facebook_prediction)


def test_predictor(
    facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, prompt: SAMPrompt
) -> None:
    predictor = facebook_sam_h_predictor
    predictor.set_image(np.array(truck))
    facebook_masks, facebook_scores, _ = predictor.predict(**prompt.facebook_predict_kwargs())  # type: ignore
    assert len(facebook_masks) == 3

    masks, scores, _ = sam_h.predict(truck, **prompt.__dict__)
    masks = masks.squeeze(0)
    scores = scores.squeeze(0)
    assert len(masks) == 3

    for i in range(3):
        mask_prediction = masks[i].cpu()
        facebook_mask = torch.as_tensor(facebook_masks[i])
        iou = intersection_over_union(mask_prediction, facebook_mask)
        assert isclose(iou, 1.0, rel_tol=5e-04), f"iou: {iou}"
        assert isclose(scores[i].item(), facebook_scores[i].item(), rel_tol=1e-04)


def test_predictor_image_embedding(sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt) -> None:
    masks_ref, scores_ref, _ = sam_h.predict(truck, **one_prompt.__dict__)

    image_embedding = sam_h.compute_image_embedding(truck)
    masks, scores, _ = sam_h.predict(image_embedding, **one_prompt.__dict__)

    assert torch.equal(masks, masks_ref)
    assert torch.equal(scores_ref, scores)


def test_predictor_dense_mask(
    facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt
) -> None:
    """
    NOTE: Binarizing intermediate masks isn't necessary, as per the SamPredictor.predict_torch docstring:

    > mask_input (np.ndarray): A low resolution mask input to the model, typically
    > coming from a previous prediction iteration. Has form Bx1xHxW, where
    > for SAM, H=W=256. Masks returned by a previous iteration of the
    > predict method do not need further transformation.
    """
    predictor = facebook_sam_h_predictor
    predictor.set_image(np.array(truck))
    facebook_masks, facebook_scores, facebook_logits = predictor.predict(
        **one_prompt.facebook_predict_kwargs(),  # type: ignore
        multimask_output=True,
    )
    assert len(facebook_masks) == 3

    facebook_mask_input = facebook_logits[np.argmax(facebook_scores)]  # shape: HxW

    # Using the same mask coordinates inputs as the official notebook
    facebook_prompt = SAMPrompt(
        foreground_points=((500, 375),),
        background_points=((1125, 625),),
        low_res_mask=facebook_mask_input[None, ...],
    )
    facebook_dense_masks, _, _ = predictor.predict(**facebook_prompt.facebook_predict_kwargs(), multimask_output=True)  # type: ignore
    assert len(facebook_dense_masks) == 3

    masks, scores, logits = sam_h.predict(truck, **one_prompt.__dict__)
    masks = masks.squeeze(0)
    scores = scores.squeeze(0)
    assert len(masks) == 3
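
    # Select the refiners logits of the best-scoring mask, mirroring `np.argmax(facebook_scores)` above.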
    mask_input = logits[:, scores.max(dim=0).indices, ...]  # shape: 1xHxW
    assert np.allclose(
        mask_input.cpu(), facebook_mask_input, atol=1e-1
    )  # Lower doesn't pass, but it's close enough for logits

    refiners_prompt = SAMPrompt(
        foreground_points=((500, 375),),
        background_points=((1125, 625),),
        low_res_mask=mask_input.unsqueeze(0),
    )
    dense_masks, _, _ = sam_h.predict(truck, **refiners_prompt.__dict__)
    dense_masks = dense_masks.squeeze(0)
    assert len(dense_masks) == 3

    for i in range(3):
        dense_mask_prediction = dense_masks[i].cpu()
        facebook_dense_mask = torch.as_tensor(facebook_dense_masks[i])
        assert dense_mask_prediction.shape == facebook_dense_mask.shape
        assert isclose(intersection_over_union(dense_mask_prediction, facebook_dense_mask), 1.0, rel_tol=5e-05)


def test_predictor_single_output(
    facebook_sam_h_predictor: FacebookSAMPredictor,
    sam_h_single_output: SegmentAnythingH,
    truck: Image.Image,
    one_prompt: SAMPrompt,
) -> None:
    predictor = facebook_sam_h_predictor
    predictor.set_image(np.array(truck))
    facebook_masks, facebook_scores, facebook_low_res_masks = predictor.predict(  # type: ignore
        **one_prompt.facebook_predict_kwargs(),  # type: ignore
        multimask_output=False,
    )
    assert len(facebook_masks) == 1

    masks, scores, low_res_masks = sam_h_single_output.predict(truck, **one_prompt.__dict__)
    masks = masks.squeeze(0)
    scores = scores.squeeze(0)
    assert len(masks) == 1

    assert torch.allclose(
        low_res_masks[0, 0, ...],
        torch.as_tensor(facebook_low_res_masks[0], device=sam_h_single_output.device),
        atol=5e-2,  # see test_predictor_resized_single_output for more explanation
    )
    assert isclose(scores[0].item(), facebook_scores[0].item(), abs_tol=1e-05)

    mask_prediction = masks[0].cpu()
    facebook_mask = torch.as_tensor(facebook_masks[0])
    assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05)


def test_predictor_resized_single_output(
    facebook_sam_h_predictor: FacebookSAMPredictor,
    sam_h_single_output: SegmentAnythingH,
    truck: Image.Image,
    one_prompt: SAMPrompt,
) -> None:
    # The refiners implementation of SAM differs from the official implementation
    # by a 6e-3 absolute diff (see test_predictor_single_output).
    # This diff comes from 2 components:
    #   * the image_encoder (see test_image_encoder)
    #   * point rescaling (facebook uses numpy while refiners uses torch)
    #
    # The current test is designed to work around those 2 components:
    #   * the facebook image_embedding is used
    #   * the image is pre-resized to (1024, 1024) so there is no rescaling
    # The test then passes with torch.equal.
    predictor = facebook_sam_h_predictor
    size = (1024, 1024)
    resized_truck = truck.resize(size)  # type: ignore
    predictor.set_image(np.array(resized_truck))
    _, _, facebook_low_res_masks = predictor.predict(  # type: ignore
        **one_prompt.facebook_predict_kwargs(),  # type: ignore
        multimask_output=False,
    )

    facebook_image_embedding = ImageEmbedding(features=predictor.features, original_image_size=size)
    _, _, low_res_masks = sam_h_single_output.predict(facebook_image_embedding, **one_prompt.__dict__)

    assert torch.equal(
        low_res_masks[0, 0, ...],
        torch.as_tensor(facebook_low_res_masks[0], device=sam_h_single_output.device),
    )


def test_mask_encoder(
    facebook_sam_h_predictor: FacebookSAMPredictor,
    sam_h: SegmentAnythingH,
    truck: Image.Image,
    one_prompt: SAMPrompt,
) -> None:
    predictor = facebook_sam_h_predictor
    predictor.set_image(np.array(truck))
    _, facebook_scores, facebook_logits = predictor.predict(
        **one_prompt.facebook_predict_kwargs(),  # type: ignore
        multimask_output=True,
    )
    facebook_mask_input = facebook_logits[np.argmax(facebook_scores)]
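
    # Turn the best-scoring low-res logits into a 1x1xHxW tensor, the mask shape the Facebook prompt encoder expects.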
    facebook_mask_input = (
        torch.from_numpy(facebook_mask_input)  # type: ignore
        .to(device=predictor.model.device)
        .unsqueeze(0)
        .unsqueeze(0)  # shape: 1x1xHxW
    )

    _, fb_dense_embeddings = predictor.model.prompt_encoder(
        points=None,
        boxes=None,
        masks=facebook_mask_input,
    )

    _, scores, logits = sam_h.predict(truck, **one_prompt.__dict__)
    scores = scores.squeeze(0)
    mask_input = logits[:, scores.max(dim=0).indices, ...].unsqueeze(0)  # shape: 1x1xHxW
    dense_embeddings = sam_h.mask_encoder(mask_input)

    assert facebook_mask_input.shape == mask_input.shape
    assert torch.allclose(dense_embeddings, fb_dense_embeddings, atol=1e-3)


@no_grad()
def test_batch_mask_decoder(sam_h: SegmentAnythingH) -> None:
    batch_size = 5

    image_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1, 1)
    mask_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1, 1)
    dense_positional_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(
        batch_size, 1, 1, 1
    )
    point_embedding = torch.randn(1, 2, 256, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1)

    sam_h.mask_decoder.set_image_embedding(image_embedding)
    sam_h.mask_decoder.set_mask_embedding(mask_embedding)
    sam_h.mask_decoder.set_point_embedding(point_embedding)
    sam_h.mask_decoder.set_dense_positional_embedding(dense_positional_embedding)

    mask_prediction, iou_prediction = sam_h.mask_decoder()

    assert mask_prediction.shape == (batch_size, 3, 256, 256)
    assert iou_prediction.shape == (batch_size, 3)
    assert torch.equal(mask_prediction[0], mask_prediction[1])