diff --git a/pyproject.toml b/pyproject.toml index aa9b9f7..b6bb35a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,11 @@ test = [ # An unofficial Python package for Meta AI's Segment Anything Model: # https://github.com/opengeos/segment-anything "segment-anything-py>=1.0", + # Official Python package for HQ-SAM + "segment-anything-hq>=0.3", + # HQ-SAM missing dependency: + # https://github.com/SysCV/sam-hq/pull/59 + "timm>=0.5.0", ] conversion = [ "diffusers>=0.26.1", @@ -140,3 +145,10 @@ exclude_also = [ [tool.typos.default] extend-words = { adaptee = "adaptee" } extend-ignore-identifiers-re = ["NDArray*"] + +[tool.pytest.ini_options] +filterwarnings = [ + "ignore::UserWarning:segment_anything_hq.modeling.tiny_vit_sam.*", + "ignore::DeprecationWarning:timm.models.layers.*", + "ignore::DeprecationWarning:timm.models.registry.*" +] diff --git a/requirements.lock b/requirements.lock index 8ba5a0d..93fdce7 100644 --- a/requirements.lock +++ b/requirements.lock @@ -255,6 +255,7 @@ safetensors==0.4.2 # via transformers scipy==1.12.0 # via bitsandbytes +segment-anything-hq==0.3 segment-anything-py==1.0 # via refiners sentry-sdk==1.40.6 @@ -270,6 +271,7 @@ smmap==5.0.1 # via gitdb sympy==1.12 # via torch +timm==0.9.16 tokenizers==0.15.2 # via transformers tomli==2.0.1 diff --git a/scripts/conversion/convert_hq_segment_anything.py b/scripts/conversion/convert_hq_segment_anything.py new file mode 100644 index 0000000..cbb36b7 --- /dev/null +++ b/scripts/conversion/convert_hq_segment_anything.py @@ -0,0 +1,81 @@ +import argparse + +from torch import Tensor + +from refiners.fluxion.utils import load_tensors, save_to_safetensors + + +def main() -> None: + parser = argparse.ArgumentParser(description="Convert HQ SAM model to Refiners state_dict format") + parser.add_argument( + "--from", + type=str, + dest="source_path", + required=True, + default="sam_hq_vit_h.pth", + help="Path to the source model checkpoint.", + ) + parser.add_argument( + "--to", + type=str, + dest="output_path", + required=True, + default="refiners_sam_hq_vit_h.safetensors", + help="Path to save the converted model in Refiners format.", + ) + args = parser.parse_args() + + source_state_dict = load_tensors(args.source_path) + + state_dict: dict[str, Tensor] = {} + + for suffix in ["weight", "bias"]: + state_dict[f"HQFeatures.CompressViTFeat.ConvTranspose2d_1.{suffix}"] = source_state_dict[ + f"mask_decoder.compress_vit_feat.0.{suffix}" + ] + state_dict[f"HQFeatures.EmbeddingEncoder.ConvTranspose2d_1.{suffix}"] = source_state_dict[ + f"mask_decoder.embedding_encoder.0.{suffix}" + ] + state_dict[f"EmbeddingMaskfeature.Conv2d_1.{suffix}"] = source_state_dict[ + f"mask_decoder.embedding_maskfeature.0.{suffix}" + ] + + state_dict[f"HQFeatures.CompressViTFeat.LayerNorm2d.{suffix}"] = source_state_dict[ + f"mask_decoder.compress_vit_feat.1.{suffix}" + ] + state_dict[f"HQFeatures.EmbeddingEncoder.LayerNorm2d.{suffix}"] = source_state_dict[ + f"mask_decoder.embedding_encoder.1.{suffix}" + ] + state_dict[f"EmbeddingMaskfeature.LayerNorm2d.{suffix}"] = source_state_dict[ + f"mask_decoder.embedding_maskfeature.1.{suffix}" + ] + + state_dict[f"HQFeatures.CompressViTFeat.ConvTranspose2d_2.{suffix}"] = source_state_dict[ + f"mask_decoder.compress_vit_feat.3.{suffix}" + ] + state_dict[f"HQFeatures.EmbeddingEncoder.ConvTranspose2d_2.{suffix}"] = source_state_dict[ + f"mask_decoder.embedding_encoder.3.{suffix}" + ] + state_dict[f"EmbeddingMaskfeature.Conv2d_2.{suffix}"] = source_state_dict[ + 
f"mask_decoder.embedding_maskfeature.3.{suffix}" + ] + + state_dict = {f"Chain.HQSAMMaskPrediction.Chain.DenseEmbeddingUpscalingHQ.{k}": v for k, v in state_dict.items()} + + # HQ Token + state_dict["MaskDecoderTokensExtender.hq_token.weight"] = source_state_dict["mask_decoder.hf_token.weight"] + + # HQ MLP + for i in range(3): + state_dict[f"Chain.HQSAMMaskPrediction.HQTokenMLP.MultiLinear.Linear_{i+1}.weight"] = source_state_dict[ + f"mask_decoder.hf_mlp.layers.{i}.weight" + ] + state_dict[f"Chain.HQSAMMaskPrediction.HQTokenMLP.MultiLinear.Linear_{i+1}.bias"] = source_state_dict[ + f"mask_decoder.hf_mlp.layers.{i}.bias" + ] + + save_to_safetensors(path=args.output_path, tensors=state_dict) + + +if __name__ == "__main__": + main() diff --git a/scripts/conversion/convert_segment_anything.py b/scripts/conversion/convert_segment_anything.py index 7ce70ce..227c58a 100644 --- a/scripts/conversion/convert_segment_anything.py +++ b/scripts/conversion/convert_segment_anything.py @@ -185,17 +185,16 @@ def convert_mask_decoder(mask_decoder: nn.Module) -> dict[str, Tensor]: mapping = converter.map_state_dicts(source_args=inputs, target_args={}) assert mapping is not None - mapping["IOUMaskEncoder"] = "iou_token" + mapping["MaskDecoderTokens.Parameter"] = "iou_token" state_dict = converter._convert_state_dict( # type: ignore source_state_dict=mask_decoder.state_dict(), target_state_dict=refiners_mask_decoder.state_dict(), state_dict_mapping=mapping, ) - state_dict["IOUMaskEncoder.weight"] = torch.cat( + state_dict["MaskDecoderTokens.Parameter.weight"] = torch.cat( tensors=[mask_decoder.iou_token.weight, mask_decoder.mask_tokens.weight], dim=0 ) # type: ignore - refiners_mask_decoder.load_state_dict(state_dict=state_dict) refiners_mask_decoder.set_image_embedding(image_embedding) @@ -254,10 +253,10 @@ def main() -> None: mask_encoder_state_dict = convert_mask_encoder(prompt_encoder=sam_h.prompt_encoder) output_state_dict = { - **{".".join(("image_encoder", key)): value for key, value in vit_state_dict.items()}, - **{".".join(("mask_decoder", key)): value for key, value in mask_decoder_state_dict.items()}, - **{".".join(("point_encoder", key)): value for key, value in point_encoder_state_dict.items()}, - **{".".join(("mask_encoder", key)): value for key, value in mask_encoder_state_dict.items()}, + **{f"SAMViTH.{key}": value for key, value in vit_state_dict.items()}, + **{f"MaskDecoder.{key}": value for key, value in mask_decoder_state_dict.items()}, + **{f"PointEncoder.{key}": value for key, value in point_encoder_state_dict.items()}, + **{f"MaskEncoder.{key}": value for key, value in mask_encoder_state_dict.items()}, } if args.half: output_state_dict = {key: value.half() for key, value in output_state_dict.items()} diff --git a/scripts/prepare_test_weights.py b/scripts/prepare_test_weights.py index 58213ac..6cde8a7 100644 --- a/scripts/prepare_test_weights.py +++ b/scripts/prepare_test_weights.py @@ -366,6 +366,13 @@ def download_sam(): ) +def download_hq_sam(): + weights_folder = os.path.join(test_weights_dir) + download_file( + "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth", weights_folder, expected_hash="66da2472" + ) + + def download_dinov2(): # For conversion weights_folder = os.path.join(test_weights_dir) @@ -661,7 +668,16 @@ def convert_sam(): "convert_segment_anything.py", "tests/weights/sam_vit_h_4b8939.pth", "tests/weights/segment-anything-h.safetensors", - expected_hash="b62ad5ed", + expected_hash="5ffb976f", + ) + + +def convert_hq_sam(): + run_conversion_script( + 
"convert_hq_segment_anything.py", + "tests/weights/sam_hq_vit_h.pth", + "tests/weights/refiners-sam-hq-vit-h.safetensors", + expected_hash="b2f5e79f", ) @@ -769,6 +785,7 @@ def download_all(): download_ip_adapter() download_t2i_adapter() download_sam() + download_hq_sam() download_dinov2() download_control_lora_fooocus() download_lcm_base() @@ -789,6 +806,7 @@ def convert_all(): convert_ip_adapter() convert_t2i_adapter() convert_sam() + convert_hq_sam() convert_dinov2() convert_control_lora_fooocus() convert_lcm_base() diff --git a/src/refiners/foundationals/segment_anything/hq_sam.py b/src/refiners/foundationals/segment_anything/hq_sam.py new file mode 100644 index 0000000..9d54aec --- /dev/null +++ b/src/refiners/foundationals/segment_anything/hq_sam.py @@ -0,0 +1,378 @@ +import torch +from torch import device as Device, dtype as DType + +import refiners.fluxion.layers as fl +from refiners.fluxion.adapters import Adapter +from refiners.fluxion.context import Contexts +from refiners.foundationals.segment_anything.image_encoder import SAMViT, TransformerLayer +from refiners.foundationals.segment_anything.mask_decoder import ( + MaskDecoderTokens, + MaskPrediction, + Predictions, +) +from refiners.foundationals.segment_anything.model import SegmentAnything + + +class CompressViTFeat(fl.Chain): + def __init__( + self, + transformer_dim: int = 256, + vit_dim: int = 1024, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + super().__init__( + fl.UseContext(context="hq_sam", key="early_vit_embedding"), + fl.Permute(0, 3, 1, 2), + fl.ConvTranspose2d( + in_channels=vit_dim, + out_channels=transformer_dim, + kernel_size=2, + stride=2, + device=device, + dtype=dtype, + ), + fl.LayerNorm2d(channels=transformer_dim, device=device, dtype=dtype), + fl.GeLU(), + fl.ConvTranspose2d( + in_channels=transformer_dim, + out_channels=transformer_dim // 8, + kernel_size=2, + stride=2, + device=device, + dtype=dtype, + ), + ) + + +class EmbeddingEncoder(fl.Chain): + def __init__( + self, + transformer_dim: int = 256, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + super().__init__( + fl.UseContext(context="mask_decoder", key="image_embedding"), + fl.ConvTranspose2d( + in_channels=transformer_dim, + out_channels=transformer_dim // 4, + kernel_size=2, + stride=2, + device=device, + dtype=dtype, + ), + fl.LayerNorm2d(channels=transformer_dim // 4, device=device, dtype=dtype), + fl.GeLU(), + fl.ConvTranspose2d( + in_channels=transformer_dim // 4, + out_channels=transformer_dim // 8, + kernel_size=2, + stride=2, + device=device, + dtype=dtype, + ), + ) + + +class HQFeatures(fl.Sum): + def __init__( + self, + vit_dim: int = 1024, + transformer_dim: int = 256, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + super().__init__( + EmbeddingEncoder(transformer_dim, device, dtype), + CompressViTFeat(transformer_dim, vit_dim, device, dtype), + ) + + +class EmbeddingMaskfeature(fl.Chain): + def __init__( + self, + transformer_dim: int = 256, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + super().__init__( + fl.UseContext(context="mask_decoder", key="upscaled_dense_embedding"), + fl.Reshape(-1, transformer_dim, transformer_dim), + fl.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1, device=device, dtype=dtype), + fl.LayerNorm2d(transformer_dim // 4, device=device, dtype=dtype), + fl.GeLU(), + fl.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1, device=device, dtype=dtype), 
+ ) + + +class DenseEmbeddingUpscalingHQ(fl.Sum): + def __init__( + self, + vit_dim: int = 1024, + transformer_dim: int = 256, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + super().__init__( + EmbeddingMaskfeature(transformer_dim, device, dtype), + HQFeatures(vit_dim, transformer_dim, device, dtype), + ) + + +class HQTokenMLP(fl.Chain): + def __init__( + self, + embedding_dim: int, + num_layers: int = 3, + target_num_mask_tokens: int = 5, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + super().__init__( + fl.Slicing(dim=1, start=target_num_mask_tokens, end=target_num_mask_tokens + 1), # HQ token + fl.MultiLinear( + input_dim=embedding_dim, + output_dim=embedding_dim // 8, + inner_dim=embedding_dim, + num_layers=num_layers, + device=device, + dtype=dtype, + ), + ) + + +class HQSAMMaskPrediction(fl.Matmul): + def __init__( + self, + embedding_dim: int, + vit_dim: int = 1024, + target_num_mask_tokens: int = 5, + num_layers: int = 3, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + super().__init__( + HQTokenMLP( + embedding_dim, + num_layers=num_layers, + target_num_mask_tokens=target_num_mask_tokens, + device=device, + dtype=dtype, + ), + fl.Chain( + DenseEmbeddingUpscalingHQ(vit_dim=vit_dim, transformer_dim=256, device=device, dtype=dtype), + fl.Flatten(start_dim=2), + ), + ) + + +class MaskPredictionAdapter(fl.Concatenate, Adapter[MaskPrediction]): + def __init__( + self, + target: MaskPrediction, + vit_dim: int = 1024, + target_num_mask_tokens: int = 5, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + with self.setup_adapter(target): + super().__init__( + target, + fl.Chain( + HQSAMMaskPrediction( + embedding_dim=target.embedding_dim, + vit_dim=vit_dim, + target_num_mask_tokens=target_num_mask_tokens, + num_layers=3, + device=device, + dtype=dtype, + ), + fl.Reshape(-1, target.embedding_dim, target.embedding_dim), + ), + dim=1, + ) + + @property + def hq_sam_mask_prediction(self) -> HQSAMMaskPrediction: + return self.ensure_find(HQSAMMaskPrediction) + + +class MaskDecoderTokensExtender(fl.Concatenate, Adapter[MaskDecoderTokens]): + """ + Add a new weight to the MaskDecoderTokens to store the new HQ token. + """ + + def __init__( + self, + target: MaskDecoderTokens, + ) -> None: + self._hq_token = [fl.Parameter(1, target.embedding_dim, device=target.device, dtype=target.dtype)] + with self.setup_adapter(target): + super().__init__( + target, + fl.Chain( + fl.UseContext(context="mask_decoder", key="image_embedding"), # use Context to infer batch size + self.hq_token, + ), + dim=1, + ) + + @property + def regular_tokens(self) -> fl.Parameter: + return self.target.ensure_find(fl.Parameter) + + @property + def hq_token(self) -> fl.Parameter: + return self._hq_token[0] + + +class SAMViTAdapter(fl.Chain, Adapter[SAMViT]): + """ + Add a context to the image encoder transformer to store its early ViT embedding + (first intermediate embedding of the ViT). 
+ """ + + def __init__(self, target: SAMViT) -> None: + with self.setup_adapter(target): + super().__init__(target) + target_transformer_layer = self._find_target_transformer_layer() + assert target_transformer_layer is not None + self._transformer_layer = [target_transformer_layer] + self._set_early_vit_embedding_context = [fl.SetContext("hq_sam", "early_vit_embedding")] + + @property + def target_transformer_layer(self) -> TransformerLayer: + return self._transformer_layer[0] + + @property + def set_early_vit_embedding_context(self) -> fl.SetContext: + return self._set_early_vit_embedding_context[0] + + def _find_target_transformer_layer(self) -> TransformerLayer | None: + for transformer_layer in self.target.layers(TransformerLayer): + if transformer_layer.window_size is None: + return transformer_layer + return None + + def inject(self: "SAMViTAdapter", parent: fl.Chain | None = None) -> "SAMViTAdapter": + self.target_transformer_layer.append(self.set_early_vit_embedding_context) + return super().inject(parent) + + def eject(self) -> None: + self.target_transformer_layer.remove(self.set_early_vit_embedding_context) + super().eject() + + +class PredictionsPostProc(fl.Module): + def __init__(self, hq_mask_only: bool = False) -> None: + super().__init__() + self.hq_mask_only = hq_mask_only + + def forward( + self, masks_predictions: torch.Tensor, iou_predictions: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + hq_sam_mask = masks_predictions[:, -1:, ...] + + # The official implementation of HQ-SAM has two outputs modes: + # 1. HQ mask only + # 2. HQ mask + base SAM mask, using HQ as a correction to the base SAM mask + # Details can be found in the paper: https://arxiv.org/abs/2306.01567 (section 3.3) + # Heuristics are provided by the authors here: https://github.com/SysCV/sam-hq/blob/3224888/demo/demo_hqsam_pip_example.py#L73-L75 + if self.hq_mask_only: + return (hq_sam_mask, iou_predictions) + + base_sam_masks = masks_predictions[:, :-1, ...] + assert base_sam_masks.shape[1] == 1 + return (hq_sam_mask + base_sam_masks, iou_predictions) + + +class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]): + """Adapter for SAM introducing HQ features. + + See [[arXiv:2306.01567] Segment Anything in High Quality](https://arxiv.org/abs/2306.01567) for details. + """ + + def init_context(self) -> Contexts: + return {"hq_sam": {"early_vit_embedding": None}} + + def __init__( + self, + target: SegmentAnything, + hq_mask_only: bool = False, + weights: dict[str, torch.Tensor] | None = None, + ) -> None: + self.vit_embedding_dim = target.image_encoder.embedding_dim + self.target_num_mask_tokens = target.mask_decoder.num_multimask_outputs + 2 + + with self.setup_adapter(target): + super().__init__(target) + + if target.mask_decoder.multimask_output: + raise NotImplementedError("Multi-mask mode is not supported in HQSAMAdapter.") + + mask_prediction = target.mask_decoder.ensure_find(MaskPrediction) + + self._mask_prediction_adapter = [ + MaskPredictionAdapter( + mask_prediction, self.vit_embedding_dim, self.target_num_mask_tokens, target.device, target.dtype + ) + ] + self._image_encoder_adapter = [SAMViTAdapter(target.image_encoder)] + self._predictions_post_proc = [PredictionsPostProc(hq_mask_only)] + + mask_decoder_tokens = target.mask_decoder.ensure_find(MaskDecoderTokens) + self._mask_decoder_tokens_extender = [MaskDecoderTokensExtender(mask_decoder_tokens)] + + if weights is not None: + hq_token_prefix = "MaskDecoderTokensExtender.hq_token." 
+ hq_token_state_dict: dict[str, torch.Tensor] = { + k.removeprefix(hq_token_prefix): v for k, v in weights.items() if k.startswith(hq_token_prefix) + } + self.mask_decoder_tokens_extender.hq_token.load_state_dict(hq_token_state_dict) + + mask_pred_prefix = "Chain.HQSAMMaskPrediction." + mask_pred_state_dict: dict[str, torch.Tensor] = { + k.removeprefix(mask_pred_prefix): v for k, v in weights.items() if k.startswith(mask_pred_prefix) + } + self.mask_prediction_adapter.hq_sam_mask_prediction.load_state_dict(mask_pred_state_dict) + + self.to(device=target.device, dtype=target.dtype) + + @property + def mask_decoder_tokens_extender(self) -> MaskDecoderTokensExtender: + return self._mask_decoder_tokens_extender[0] + + @property + def mask_prediction_adapter(self) -> MaskPredictionAdapter: + return self._mask_prediction_adapter[0] + + @property + def image_encoder_adapter(self) -> SAMViTAdapter: + return self._image_encoder_adapter[0] + + @property + def predictions_post_proc(self) -> PredictionsPostProc: + return self._predictions_post_proc[0] + + @property + def hq_mask_only(self) -> bool: + return self.predictions_post_proc.hq_mask_only + + @hq_mask_only.setter + def hq_mask_only(self, value: bool) -> None: + self.predictions_post_proc.hq_mask_only = value + + def inject(self: "HQSAMAdapter", parent: fl.Chain | None = None) -> "HQSAMAdapter": + self.mask_decoder_tokens_extender.inject() + self.mask_prediction_adapter.inject() + self.image_encoder_adapter.inject() + self.target.mask_decoder.insert_after_type(Predictions, self.predictions_post_proc) + return super().inject(parent) + + def eject(self) -> None: + self.mask_decoder_tokens_extender.eject() + self.mask_prediction_adapter.eject() + self.image_encoder_adapter.eject() + self.target.mask_decoder.remove(self.predictions_post_proc) + super().eject() diff --git a/src/refiners/foundationals/segment_anything/mask_decoder.py b/src/refiners/foundationals/segment_anything/mask_decoder.py index 0827dd0..25b17aa 100644 --- a/src/refiners/foundationals/segment_anything/mask_decoder.py +++ b/src/refiners/foundationals/segment_anything/mask_decoder.py @@ -1,5 +1,5 @@ import torch -from torch import Tensor, device as Device, dtype as DType, nn +from torch import Tensor, device as Device, dtype as DType import refiners.fluxion.layers as fl from refiners.fluxion.context import Contexts @@ -10,7 +10,7 @@ from refiners.foundationals.segment_anything.transformer import ( class EmbeddingsAggregator(fl.ContextModule): - def forward(self, iou_mask_tokens: Tensor) -> Tensor: + def forward(self, tokens: Tensor) -> Tensor: mask_decoder = self.ensure_parent mask_decoder_context = mask_decoder.use_context(context_name="mask_decoder") image_embedding = mask_decoder_context["image_embedding"] @@ -18,7 +18,7 @@ class EmbeddingsAggregator(fl.ContextModule): mask_embedding = mask_decoder_context["mask_embedding"] dense_positional_embedding = mask_decoder_context["dense_positional_embedding"] - sparse_embedding = torch.cat(tensors=(iou_mask_tokens, point_embedding), dim=1) + sparse_embedding = torch.cat(tensors=(tokens, point_embedding), dim=1) dense_embedding = (image_embedding + mask_embedding).flatten(start_dim=2).transpose(1, 2) if dense_positional_embedding.shape != dense_embedding.shape: dense_positional_embedding = dense_positional_embedding.flatten(start_dim=2).transpose(1, 2) @@ -108,10 +108,11 @@ class DenseEmbeddingUpscaling(fl.Chain): ), fl.GeLU(), fl.Flatten(start_dim=2), + fl.SetContext(context="mask_decoder", key="upscaled_dense_embedding"), ) -class 
IOUMaskEncoder(fl.WeightedModule): +class MaskDecoderTokens(fl.Chain): def __init__( self, embedding_dim: int = 256, @@ -119,14 +120,13 @@ class IOUMaskEncoder(fl.WeightedModule): device: Device | str | None = None, dtype: DType | None = None, ) -> None: - super().__init__() self.embedding_dim = embedding_dim self.num_mask_tokens = num_mask_tokens - # aka prompt tokens + output token (for IoU scores prediction) - self.weight = nn.Parameter(data=torch.randn(num_mask_tokens + 1, embedding_dim, device=device, dtype=dtype)) - - def forward(self) -> Tensor: - return self.weight.unsqueeze(dim=0) + # aka output tokens (single-mask output + multi-mask output) + IoU token + super().__init__( + fl.UseContext(context="mask_decoder", key="image_embedding"), # use Context to infer batch size + fl.Parameter(num_mask_tokens + 1, embedding_dim, device=device, dtype=dtype), + ) class MaskPrediction(fl.Chain): @@ -178,7 +178,6 @@ class IOUPrediction(fl.Chain): super().__init__( fl.Slicing(dim=1, start=0, end=1), - fl.Squeeze(dim=0), fl.MultiLinear( input_dim=embedding_dim, output_dim=num_mask_tokens, @@ -188,6 +187,39 @@ class IOUPrediction(fl.Chain): dtype=dtype, ), fl.Slicing(dim=-1, start=1) if multimask_output else fl.Slicing(dim=-1, start=0, end=1), + fl.Squeeze(dim=1), + ) + + +class Predictions(fl.Parallel): + def __init__( + self, + embedding_dim: int, + num_mask_tokens: int, + multimask_output: bool, + num_layers: int = 3, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + self.embedding_dim = embedding_dim + self.num_mask_tokens = num_mask_tokens + self.num_layers = num_layers + super().__init__( + MaskPrediction( + embedding_dim=embedding_dim, + num_mask_tokens=num_mask_tokens, + multimask_output=multimask_output, + device=device, + dtype=dtype, + ), + IOUPrediction( + embedding_dim=embedding_dim, + num_layers=num_layers, + num_mask_tokens=num_mask_tokens, + multimask_output=multimask_output, + device=device, + dtype=dtype, + ), ) @@ -213,7 +245,7 @@ class MaskDecoder(fl.Chain): num_mask_tokens = self.num_multimask_outputs + 1 super().__init__( - IOUMaskEncoder(embedding_dim=embedding_dim, num_mask_tokens=num_mask_tokens, device=device, dtype=dtype), + MaskDecoderTokens(embedding_dim=embedding_dim, num_mask_tokens=num_mask_tokens, device=device, dtype=dtype), EmbeddingsAggregator(), Transformer( *( @@ -230,22 +262,12 @@ class MaskDecoder(fl.Chain): SparseCrossDenseAttention(embedding_dim=embedding_dim, device=device, dtype=dtype), fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype), ), - fl.Parallel( - MaskPrediction( - embedding_dim=embedding_dim, - num_mask_tokens=num_mask_tokens, - multimask_output=multimask_output, - device=device, - dtype=dtype, - ), - IOUPrediction( - embedding_dim=embedding_dim, - num_layers=3, - num_mask_tokens=num_mask_tokens, - multimask_output=multimask_output, - device=device, - dtype=dtype, - ), + Predictions( + embedding_dim=embedding_dim, + num_mask_tokens=num_mask_tokens, + multimask_output=multimask_output, + device=device, + dtype=dtype, ), ) diff --git a/src/refiners/foundationals/segment_anything/model.py b/src/refiners/foundationals/segment_anything/model.py index 4e3b03d..c178bd0 100644 --- a/src/refiners/foundationals/segment_anything/model.py +++ b/src/refiners/foundationals/segment_anything/model.py @@ -20,7 +20,7 @@ class ImageEmbedding: original_image_size: tuple[int, int] # (height, width) -class SegmentAnything(fl.Module): +class SegmentAnything(fl.Chain): """SegmentAnything model. 
See [[arXiv:2304.02643] Segment Anything](https://arxiv.org/abs/2304.02643) @@ -47,16 +47,30 @@ class SegmentAnything(fl.Module): point_encoder: The point encoder to use. mask_encoder: The mask encoder to use. mask_decoder: The mask decoder to use. - device: The PyTorch device to use. - dtype: The PyTorch data type to use. """ - super().__init__() - self.device: Device = device if isinstance(device, Device) else Device(device=device) - self.dtype = dtype - self.image_encoder = image_encoder.to(device=self.device, dtype=self.dtype) - self.point_encoder = point_encoder.to(device=self.device, dtype=self.dtype) - self.mask_encoder = mask_encoder.to(device=self.device, dtype=self.dtype) - self.mask_decoder = mask_decoder.to(device=self.device, dtype=self.dtype) + super().__init__(image_encoder, point_encoder, mask_encoder, mask_decoder) + + self.to(device=device, dtype=dtype) + + @property + def image_encoder(self) -> SAMViT: + """The image encoder.""" + return self.ensure_find(SAMViT) + + @property + def point_encoder(self) -> PointEncoder: + """The point encoder.""" + return self.ensure_find(PointEncoder) + + @property + def mask_encoder(self) -> MaskEncoder: + """The mask encoder.""" + return self.ensure_find(MaskEncoder) + + @property + def mask_decoder(self) -> MaskDecoder: + """The mask decoder.""" + return self.ensure_find(MaskDecoder) @no_grad() def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding: @@ -259,11 +273,11 @@ class SegmentAnythingH(SegmentAnything): else: mask_decoder = MaskDecoder(multimask_output) if multimask_output is not None else MaskDecoder() - super().__init__( - image_encoder=image_encoder, - point_encoder=point_encoder, - mask_encoder=mask_encoder, - mask_decoder=mask_decoder, - device=device, - dtype=dtype, - ) + super().__init__(image_encoder, point_encoder, mask_encoder, mask_decoder) + + self.to(device=device, dtype=dtype) + + @property + def image_encoder(self) -> SAMViTH: + """The image encoder.""" + return self.ensure_find(SAMViTH) diff --git a/src/refiners/foundationals/segment_anything/prompt_encoder.py b/src/refiners/foundationals/segment_anything/prompt_encoder.py index f062ecb..f5553f3 100644 --- a/src/refiners/foundationals/segment_anything/prompt_encoder.py +++ b/src/refiners/foundationals/segment_anything/prompt_encoder.py @@ -98,7 +98,7 @@ class PointEncoder(fl.Chain): ) -> Float[Tensor, "num_positional_features height width"]: coordinate_encoder = self.ensure_find(layer_type=CoordinateEncoder) height, width = image_embedding_size - grid = torch.ones((height, width), device=self.device, dtype=torch.float32) + grid = torch.ones((height, width), device=self.device, dtype=self.dtype) y_embedding = grid.cumsum(dim=0) - 0.5 x_embedding = grid.cumsum(dim=1) - 0.5 y_embedding = y_embedding / height diff --git a/tests/foundationals/segment_anything/conftest.py b/tests/foundationals/segment_anything/conftest.py new file mode 100644 index 0000000..bbc8e59 --- /dev/null +++ b/tests/foundationals/segment_anything/conftest.py @@ -0,0 +1,18 @@ +from pathlib import Path +from warnings import warn + +from pytest import fixture, skip + + +@fixture(scope="package") +def ref_path(test_sam_path: Path) -> Path: + return test_sam_path / "test_sam_ref" + + +@fixture(scope="package") +def sam_h_weights(test_weights_path: Path) -> Path: + sam_h_weights = test_weights_path / "segment-anything-h.safetensors" + if not sam_h_weights.is_file(): + warn(f"could not find weights at {sam_h_weights}, skipping") + skip(allow_module_level=True) + return sam_h_weights 
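Editor's sketch (not part of the diff): a minimal end-to-end use of the adapter this change introduces, mirroring the fixtures and shape assertions in the test file that follows. The weight paths are the ones produced by prepare_test_weights.py and are illustrative; the device selection is an assumption.

import torch
from PIL import Image

from refiners.fluxion.utils import load_from_safetensors, no_grad
from refiners.foundationals.segment_anything.hq_sam import HQSAMAdapter
from refiners.foundationals.segment_anything.model import SegmentAnythingH

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # assumption: any device works

# HQSAMAdapter only supports the single-mask output mode, hence multimask_output=False.
sam_h = SegmentAnythingH(multimask_output=False, device=device)
sam_h.load_from_safetensors(tensors_path="tests/weights/segment-anything-h.safetensors")

adapter = HQSAMAdapter(
    sam_h,
    hq_mask_only=False,  # False: add the HQ correction to the base SAM mask (see PredictionsPostProc)
    weights=load_from_safetensors("tests/weights/refiners-sam-hq-vit-h.safetensors"),
)
adapter.inject()

image = Image.open("tests/foundationals/segment_anything/test_sam_ref/tennis.png").convert("RGB")
with no_grad():
    high_res_masks, iou_predictions, low_res_masks = sam_h.predict(image, box_points=[[(4, 13), (1007, 1023)]])

# With a single box prompt, shapes are (1, 1, 1024, 1024), (1, 1) and (1, 1, 256, 256) respectively.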
diff --git a/tests/foundationals/segment_anything/test_hq_sam.py b/tests/foundationals/segment_anything/test_hq_sam.py new file mode 100644 index 0000000..3d79eaf --- /dev/null +++ b/tests/foundationals/segment_anything/test_hq_sam.py @@ -0,0 +1,296 @@ +from pathlib import Path +from typing import cast +from warnings import warn + +import numpy as np +import pytest +import torch +from PIL import Image +from segment_anything_hq import ( # type: ignore + SamPredictor as SamPredictorHQ, + sam_model_registry as sam_model_registry_hq, # type: ignore +) +from segment_anything_hq.modeling.sam import Sam # type: ignore +from tests.foundationals.segment_anything.utils import FacebookSAM, FacebookSAMPredictorHQ, SAMPrompt +from torch import optim + +from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, no_grad +from refiners.foundationals.segment_anything.hq_sam import ( + CompressViTFeat, + EmbeddingEncoder, + HQSAMAdapter, + HQTokenMLP, + MaskDecoderTokensExtender, + PredictionsPostProc, +) +from refiners.foundationals.segment_anything.model import SegmentAnythingH + + +@pytest.fixture(scope="module") +def one_prompt() -> SAMPrompt: + return SAMPrompt(box_points=[[(4, 13), (1007, 1023)]]) + + +@pytest.fixture(scope="module") +def tennis(ref_path: Path) -> Image.Image: + return Image.open(ref_path / "tennis.png").convert("RGB") + + +@pytest.fixture(scope="module") +def hq_adapter_weights(test_weights_path: Path) -> Path: + """Path to the HQ adapter weights in Refiners format""" + refiners_hq_adapter_sam_weights = test_weights_path / "refiners-sam-hq-vit-h.safetensors" + if not refiners_hq_adapter_sam_weights.is_file(): + warn(f"Test weights not found at {refiners_hq_adapter_sam_weights}, skipping") + pytest.skip(allow_module_level=True) + return refiners_hq_adapter_sam_weights + + +@pytest.fixture +def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH: + # HQSAMAdapter is designed to be used with single-output only, hence multimask_output=False. 
+ sam_h = SegmentAnythingH(multimask_output=False, device=test_device) + sam_h.load_from_safetensors(tensors_path=sam_h_weights) + return sam_h + + +@pytest.fixture(scope="module") +def reference_hq_adapter_weights(test_weights_path: Path) -> Path: + """Path to the HQ adapter weights in default format""" + reference_hq_adapter_sam_weights = test_weights_path / "sam_hq_vit_h.pth" + if not reference_hq_adapter_sam_weights.is_file(): + warn(f"Test weights not found at {reference_hq_adapter_sam_weights}, skipping") + pytest.skip(allow_module_level=True) + return reference_hq_adapter_sam_weights + + +@pytest.fixture(scope="module") +def reference_sam_h(reference_hq_adapter_weights: Path, test_device: torch.device) -> FacebookSAM: + sam_h = cast(FacebookSAM, sam_model_registry_hq["vit_h"](checkpoint=reference_hq_adapter_weights)) + return sam_h.to(device=test_device) + + +@pytest.fixture(scope="module") +def reference_sam_h_predictor(reference_sam_h: FacebookSAM) -> FacebookSAMPredictorHQ: + predictor = SamPredictorHQ(cast(Sam, reference_sam_h)) + return cast(FacebookSAMPredictorHQ, predictor) + + +def test_inject_eject() -> None: + sam_h = SegmentAnythingH(multimask_output=False) + initial_repr = repr(sam_h) + adapter = HQSAMAdapter(sam_h) + assert repr(sam_h) == initial_repr + adapter.inject() + assert repr(sam_h) != initial_repr + adapter.eject() + assert repr(sam_h) == initial_repr + + +def test_multimask_forbidden() -> None: + with pytest.raises(NotImplementedError, match="not supported"): + HQSAMAdapter(target=SegmentAnythingH(multimask_output=True)) + + +def test_output_shape_hq_adapter(tennis: Image.Image, one_prompt: SAMPrompt) -> None: + sam_h = SegmentAnythingH(multimask_output=False) + HQSAMAdapter(sam_h).inject() + high_res_masks, iou_predictions, low_res_masks = sam_h.predict(tennis, **one_prompt.__dict__) + assert high_res_masks.shape == (1, 1, 1024, 1024) + assert iou_predictions.shape == (1, 1) + assert low_res_masks.shape == (1, 1, 256, 256) + + +def test_mask_decoder_tokens_extender() -> None: + sam_h = SegmentAnythingH(multimask_output=False) + sam_h.requires_grad_(False) + + # MaskDecoderTokens requires image_embedding context to be set + image_embedding = torch.randn(2, 256, 64, 64) + sam_h.mask_decoder.set_image_embedding(image_embedding) + + HQSAMAdapter(sam_h).inject() + + mask_decoder_tokens = sam_h.ensure_find(MaskDecoderTokensExtender) + + tokens_before = mask_decoder_tokens() + assert tokens_before.shape == torch.Size([2, 6, 256]) + + for p in mask_decoder_tokens.parameters(): + match p.shape: + case torch.Size([5, 256]): + assert not p.requires_grad + case torch.Size([1, 256]): + assert p.requires_grad + case _: + raise ValueError + + optimizer = optim.SGD(mask_decoder_tokens.parameters(), lr=10) + optimizer.zero_grad() + + ones = torch.ones_like(tokens_before) + loss = torch.nn.functional.mse_loss(tokens_before, ones) + loss.backward() # type: ignore + optimizer.step() + + tokens_after = mask_decoder_tokens() + + assert torch.equal(tokens_before[:, :5, :], tokens_after[:, :5, :]) + assert not torch.equal(tokens_before[:, 5, :], tokens_after[:, 5, :]) + + +@no_grad() +def test_early_vit_embedding( + sam_h: SegmentAnythingH, + hq_adapter_weights: Path, + reference_sam_h: FacebookSAM, + tennis: Image.Image, +) -> None: + HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() + + image_tensor = image_to_tensor(image=tennis.resize(size=(1024, 1024))) + + _ = sam_h.image_encoder(image_tensor.to(sam_h.device)) + early_vit_embedding_refiners = 
sam_h.use_context(context_name="hq_sam")["early_vit_embedding"] + + _, intermediate_embeddings = reference_sam_h.image_encoder(image_tensor.to(reference_sam_h.device)) + early_vit_embedding = intermediate_embeddings[0] + + assert torch.equal(early_vit_embedding, early_vit_embedding_refiners) + + +def test_tokens(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None: + HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() + + mask_decoder_tokens_extender = sam_h.mask_decoder.ensure_find(MaskDecoderTokensExtender) + + # HF Token (1, 256) + assert torch.equal(reference_sam_h.mask_decoder.hf_token.weight, mask_decoder_tokens_extender.hq_token.weight) + + # Regular Tokens (5, 256) + assert torch.equal( + torch.cat([reference_sam_h.mask_decoder.iou_token.weight, reference_sam_h.mask_decoder.mask_tokens.weight]), + mask_decoder_tokens_extender.regular_tokens.weight, + ) + + +@no_grad() +def test_compress_vit_feat(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None: + HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() + + early_vit_embedding = torch.randn(1, 64, 64, 1280, device=sam_h.device, dtype=sam_h.dtype) + + sam_h.set_context(context="hq_sam", value={"early_vit_embedding": early_vit_embedding}) + refiners_output = sam_h.ensure_find(CompressViTFeat)() + + reference_output = reference_sam_h.mask_decoder.compress_vit_feat(early_vit_embedding.permute(0, 3, 1, 2)) + + assert torch.equal(refiners_output, reference_output) + + +@no_grad() +def test_embedding_encoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None: + HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() + + x = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype) + + sam_h.set_context(context="mask_decoder", value={"image_embedding": x}) + refiners_output = sam_h.ensure_find(EmbeddingEncoder)() + + reference_output = reference_sam_h.mask_decoder.embedding_encoder(x) + + assert torch.equal(refiners_output, reference_output) + + +@no_grad() +def test_hq_token_mlp(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None: + HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() + + x = torch.randn(1, 6, 256, device=sam_h.device, dtype=sam_h.dtype) + + refiners_output = sam_h.ensure_find(HQTokenMLP)(x) + reference_output = reference_sam_h.mask_decoder.hf_mlp(x[:, -1, :]).unsqueeze(0) + + assert torch.equal(refiners_output, reference_output) + + +@pytest.mark.parametrize("hq_mask_only", [True, False]) +def test_predictor( + sam_h: SegmentAnythingH, + hq_adapter_weights: Path, + hq_mask_only: bool, + reference_sam_h_predictor: FacebookSAMPredictorHQ, + tennis: Image.Image, + one_prompt: SAMPrompt, +) -> None: + adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() + + adapter.hq_mask_only = hq_mask_only + assert sam_h.ensure_find(PredictionsPostProc).hq_mask_only == hq_mask_only + + # Refiners + high_res_masks, iou_predictions, low_res_masks = sam_h.predict(tennis, **one_prompt.__dict__) + refiners_high_res_mask_hq = high_res_masks[0, 0, ...].to(dtype=torch.float32).detach().cpu() + refiners_low_res_mask_hq = low_res_masks[0, 0, ...].to(dtype=torch.float32).detach().cpu() + iou_predictions = iou_predictions[0, :].to(dtype=torch.float32).detach().cpu() + + # Reference + reference_sam_h_predictor.set_image(np.array(tennis)) + + predictor_prompt 
= one_prompt.__dict__["box_points"] + masks_np, iou_predictions_np, low_res_masks_np = reference_sam_h_predictor.predict( + box=np.array(predictor_prompt).flatten(), + multimask_output=False, + hq_token_only=hq_mask_only, + ) + + reference_high_res_mask_hq = torch.from_numpy(masks_np[0, ...]).to(dtype=torch.float32) # type: ignore + reference_low_res_mask_hq = torch.from_numpy(low_res_masks_np[0, ...]).to(dtype=torch.float32) # type: ignore + iou_predictions_np = torch.from_numpy(iou_predictions_np).to(dtype=torch.float32) # type: ignore + + # NOTE: Diff on logits is relatively high, but on the same scale / even lower than base SAM logits diff (6e-3) + # See https://github.com/finegrain-ai/refiners/blob/c6b5eb24a179d48e4542d94684a70c5ef3142ab1/tests/foundationals/segment_anything/test_sam.py#L426 + assert torch.allclose( + reference_low_res_mask_hq, + refiners_low_res_mask_hq, + atol=4e-3, + ) + assert ( + torch.abs(reference_high_res_mask_hq - refiners_high_res_mask_hq).flatten().sum() <= 1 + ) # The diff on the logits above leads to an absolute diff of 1 pixel on the high res masks + assert torch.allclose( + iou_predictions_np, + torch.max(iou_predictions), + atol=1e-5, + ) + + +@no_grad() +def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) -> None: + HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() + + batch_size = 5 + + image_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1, 1) + mask_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1, 1) + dense_positional_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat( + batch_size, 1, 1, 1 + ) + point_embedding = torch.randn(1, 2, 256, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1) + early_vit_embedding = torch.randn(1, 64, 64, 1280, device=sam_h.device, dtype=sam_h.dtype).repeat( + batch_size, 1, 1, 1 + ) + + sam_h.mask_decoder.set_image_embedding(image_embedding) + sam_h.mask_decoder.set_mask_embedding(mask_embedding) + sam_h.mask_decoder.set_point_embedding(point_embedding) + sam_h.mask_decoder.set_dense_positional_embedding(dense_positional_embedding) + sam_h.mask_decoder.set_context( + context="hq_sam", value={"early_vit_embedding": early_vit_embedding.to(sam_h.device, sam_h.dtype)} + ) + + mask_prediction, iou_prediction = sam_h.mask_decoder() + + assert mask_prediction.shape == (batch_size, 1, 256, 256) + assert iou_prediction.shape == (batch_size, 1) + assert torch.equal(mask_prediction[0], mask_prediction[1]) diff --git a/tests/foundationals/segment_anything/test_sam.py b/tests/foundationals/segment_anything/test_sam.py index e40eab5..9a8f4fb 100644 --- a/tests/foundationals/segment_anything/test_sam.py +++ b/tests/foundationals/segment_anything/test_sam.py @@ -56,15 +56,6 @@ def facebook_sam_h_weights(test_weights_path: Path) -> Path: return sam_h_weights -@pytest.fixture(scope="module") -def sam_h_weights(test_weights_path: Path) -> Path: - sam_h_weights = test_weights_path / "segment-anything-h.safetensors" - if not sam_h_weights.is_file(): - warn(f"could not find weights at {sam_h_weights}, skipping") - pytest.skip(allow_module_level=True) - return sam_h_weights - - @pytest.fixture(scope="module") def facebook_sam_h(facebook_sam_h_weights: Path, test_device: torch.device) -> FacebookSAM: from segment_anything import build_sam_vit_h # type: ignore @@ -98,11 +89,6 @@ def sam_h_single_output(sam_h_weights: Path, 
test_device: torch.device) -> Segme @pytest.fixture(scope="module") -def ref_path(test_sam_path: Path) -> Path: - return test_sam_path / "test_sam_ref" - - -@pytest.fixture def truck(ref_path: Path) -> Image.Image: return Image.open(ref_path / "truck.jpg").convert("RGB") @@ -283,14 +269,14 @@ def test_mask_decoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> N mapping = converter.map_state_dicts(source_args=inputs, target_args={}) assert mapping is not None - mapping["IOUMaskEncoder"] = "iou_token" + mapping["MaskDecoderTokens.Parameter"] = "iou_token" state_dict = converter._convert_state_dict( # type: ignore source_state_dict=facebook_mask_decoder.state_dict(), target_state_dict=refiners_mask_decoder.state_dict(), state_dict_mapping=mapping, ) - state_dict["IOUMaskEncoder.weight"] = torch.cat( + state_dict["MaskDecoderTokens.Parameter.weight"] = torch.cat( [facebook_mask_decoder.iou_token.weight, facebook_mask_decoder.mask_tokens.weight], dim=0 ) # type: ignore refiners_mask_decoder.load_state_dict(state_dict=state_dict) @@ -462,3 +448,26 @@ def test_mask_encoder( assert facebook_mask_input.shape == mask_input.shape assert torch.allclose(dense_embeddings, fb_dense_embeddings, atol=1e-4, rtol=1e-4) + + +@no_grad() +def test_batch_mask_decoder(sam_h: SegmentAnythingH) -> None: + batch_size = 5 + + image_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1, 1) + mask_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1, 1) + dense_positional_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat( + batch_size, 1, 1, 1 + ) + point_embedding = torch.randn(1, 2, 256, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1) + + sam_h.mask_decoder.set_image_embedding(image_embedding) + sam_h.mask_decoder.set_mask_embedding(mask_embedding) + sam_h.mask_decoder.set_point_embedding(point_embedding) + sam_h.mask_decoder.set_dense_positional_embedding(dense_positional_embedding) + + mask_prediction, iou_prediction = sam_h.mask_decoder() + + assert mask_prediction.shape == (batch_size, 3, 256, 256) + assert iou_prediction.shape == (batch_size, 3) + assert torch.equal(mask_prediction[0], mask_prediction[1]) diff --git a/tests/foundationals/segment_anything/test_sam_ref/README.md b/tests/foundationals/segment_anything/test_sam_ref/README.md index 0ad3a32..8abad0b 100644 --- a/tests/foundationals/segment_anything/test_sam_ref/README.md +++ b/tests/foundationals/segment_anything/test_sam_ref/README.md @@ -1,3 +1,5 @@ # Note about this data `truck.jpg` is one of the [images](https://github.com/facebookresearch/segment-anything/tree/main/notebooks/images) used in the official [segment-anything notebooks](https://github.com/facebookresearch/segment-anything/tree/main/notebooks). + +`tennis.png` is one of the [images](https://github.com/SysCV/sam-hq/tree/main/demo/input_imgs) used in the official [sam-hq demos](https://github.com/SysCV/sam-hq/tree/main/demo). 
diff --git a/tests/foundationals/segment_anything/test_sam_ref/tennis.png b/tests/foundationals/segment_anything/test_sam_ref/tennis.png new file mode 100644 index 0000000..726f662 Binary files /dev/null and b/tests/foundationals/segment_anything/test_sam_ref/tennis.png differ diff --git a/tests/foundationals/segment_anything/utils.py b/tests/foundationals/segment_anything/utils.py index ca3f239..b397359 100644 --- a/tests/foundationals/segment_anything/utils.py +++ b/tests/foundationals/segment_anything/utils.py @@ -54,6 +54,23 @@ class FacebookSAMPredictor: ) -> tuple[NDArray, NDArray, NDArray]: ... +class FacebookSAMPredictorHQ: + model: FacebookSAM + + def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ... + + def predict( + self, + point_coords: NDArray | None = None, + point_labels: NDArray | None = None, + box: NDArray | None = None, + mask_input: NDArray | None = None, + multimask_output: bool = True, + return_logits: bool = False, + hq_token_only: bool = False, + ) -> tuple[NDArray, NDArray, NDArray]: ... + + @dataclass class SAMPrompt: foreground_points: Sequence[tuple[float, float]] | None = None
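Editor's sketch (not part of the diff): how the reference predictor stubbed out by FacebookSAMPredictorHQ above is driven, following test_predictor in test_hq_sam.py. It assumes the original sam_hq_vit_h.pth checkpoint fetched by download_hq_sam is present at the path shown.

from typing import cast

import numpy as np
from PIL import Image
from segment_anything_hq import SamPredictor as SamPredictorHQ, sam_model_registry as sam_model_registry_hq  # type: ignore
from segment_anything_hq.modeling.sam import Sam  # type: ignore

# Build the official HQ-SAM ViT-H model and wrap it in the HQ predictor.
reference_sam_h = sam_model_registry_hq["vit_h"](checkpoint="tests/weights/sam_hq_vit_h.pth")
predictor = SamPredictorHQ(cast(Sam, reference_sam_h))

image = Image.open("tests/foundationals/segment_anything/test_sam_ref/tennis.png").convert("RGB")
predictor.set_image(np.array(image))

# hq_token_only=True corresponds to hq_mask_only=True on the Refiners adapter
# (return the HQ mask alone instead of HQ + base SAM mask).
masks, iou_predictions, low_res_masks = predictor.predict(
    box=np.array([[(4, 13), (1007, 1023)]]).flatten(),
    multimask_output=False,
    hq_token_only=False,
)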