mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
Add HQ-SAM Adapter
This commit is contained in:
parent
c6b5eb24a1
commit
a93ceff752
|
@ -44,6 +44,11 @@ test = [
|
||||||
# An unofficial Python package for Meta AI's Segment Anything Model:
|
# An unofficial Python package for Meta AI's Segment Anything Model:
|
||||||
# https://github.com/opengeos/segment-anything
|
# https://github.com/opengeos/segment-anything
|
||||||
"segment-anything-py>=1.0",
|
"segment-anything-py>=1.0",
|
||||||
|
# Official Python package for HQ-SAM
|
||||||
|
"segment-anything-hq>=0.3",
|
||||||
|
# HQ-SAM missing dependency:
|
||||||
|
# https://github.com/SysCV/sam-hq/pull/59
|
||||||
|
"timm>=0.5.0",
|
||||||
]
|
]
|
||||||
conversion = [
|
conversion = [
|
||||||
"diffusers>=0.26.1",
|
"diffusers>=0.26.1",
|
||||||
|
@ -140,3 +145,10 @@ exclude_also = [
|
||||||
[tool.typos.default]
|
[tool.typos.default]
|
||||||
extend-words = { adaptee = "adaptee" }
|
extend-words = { adaptee = "adaptee" }
|
||||||
extend-ignore-identifiers-re = ["NDArray*"]
|
extend-ignore-identifiers-re = ["NDArray*"]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
filterwarnings = [
|
||||||
|
"ignore::UserWarning:segment_anything_hq.modeling.tiny_vit_sam.*",
|
||||||
|
"ignore::DeprecationWarning:timm.models.layers.*",
|
||||||
|
"ignore::DeprecationWarning:timm.models.registry.*"
|
||||||
|
]
|
||||||
|
|
|
@ -255,6 +255,7 @@ safetensors==0.4.2
|
||||||
# via transformers
|
# via transformers
|
||||||
scipy==1.12.0
|
scipy==1.12.0
|
||||||
# via bitsandbytes
|
# via bitsandbytes
|
||||||
|
segment-anything-hq==0.3
|
||||||
segment-anything-py==1.0
|
segment-anything-py==1.0
|
||||||
# via refiners
|
# via refiners
|
||||||
sentry-sdk==1.40.6
|
sentry-sdk==1.40.6
|
||||||
|
@ -270,6 +271,7 @@ smmap==5.0.1
|
||||||
# via gitdb
|
# via gitdb
|
||||||
sympy==1.12
|
sympy==1.12
|
||||||
# via torch
|
# via torch
|
||||||
|
timm==0.9.16
|
||||||
tokenizers==0.15.2
|
tokenizers==0.15.2
|
||||||
# via transformers
|
# via transformers
|
||||||
tomli==2.0.1
|
tomli==2.0.1
|
||||||
|
|
81
scripts/conversion/convert_hq_segment_anything.py
Normal file
81
scripts/conversion/convert_hq_segment_anything.py
Normal file
|
@ -0,0 +1,81 @@
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import load_tensors, save_to_safetensors
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Convert an HQ-SAM checkpoint to the Refiners state_dict layout.

    Loads the tensors found at ``--from``, copies the HQ-specific mask decoder
    tensors under the key names expected by the Refiners HQ-SAM adapter
    (``MaskDecoderTokensExtender.hq_token.*`` and ``Chain.HQSAMMaskPrediction.*``),
    and saves them as a safetensors file at ``--to``.

    Raises:
        KeyError: if the source checkpoint lacks one of the expected
            ``mask_decoder.*`` entries.
    """
    parser = argparse.ArgumentParser(description="Convert HQ SAM model to Refiners state_dict format")
    # Neither option is marked `required=True`: argparse never applies the
    # `default` of a required option, which made the defaults below dead code.
    parser.add_argument(
        "--from",
        type=str,
        dest="source_path",
        default="sam_hq_vit_h.pth",
        help="Path to the source model checkpoint.",
    )
    parser.add_argument(
        "--to",
        type=str,
        dest="output_path",
        default="refiners_sam_hq_vit_h.safetensors",
        help="Path to save the converted model in Refiners format.",
    )
    args = parser.parse_args()

    source_state_dict = load_tensors(args.source_path)

    # (Refiners layer path, HQ-SAM layer path) pairs of the dense embedding
    # upscaling branch, listed in the order the keys are inserted.
    upscaling_layers = [
        ("HQFeatures.CompressViTFeat.ConvTranspose2d_1", "compress_vit_feat.0"),
        ("HQFeatures.EmbeddingEncoder.ConvTranspose2d_1", "embedding_encoder.0"),
        ("EmbeddingMaskfeature.Conv2d_1", "embedding_maskfeature.0"),
        ("HQFeatures.CompressViTFeat.LayerNorm2d", "compress_vit_feat.1"),
        ("HQFeatures.EmbeddingEncoder.LayerNorm2d", "embedding_encoder.1"),
        ("EmbeddingMaskfeature.LayerNorm2d", "embedding_maskfeature.1"),
        ("HQFeatures.CompressViTFeat.ConvTranspose2d_2", "compress_vit_feat.3"),
        ("HQFeatures.EmbeddingEncoder.ConvTranspose2d_2", "embedding_encoder.3"),
        ("EmbeddingMaskfeature.Conv2d_2", "embedding_maskfeature.3"),
    ]
    upscaling_prefix = "Chain.HQSAMMaskPrediction.Chain.DenseEmbeddingUpscalingHQ"
    state_dict: dict[str, Tensor] = {
        f"{upscaling_prefix}.{target}.{suffix}": source_state_dict[f"mask_decoder.{source}.{suffix}"]
        for suffix in ("weight", "bias")
        for target, source in upscaling_layers
    }

    # HQ Token (named `hf_token` in the source checkpoint)
    state_dict["MaskDecoderTokensExtender.hq_token.weight"] = source_state_dict["mask_decoder.hf_token.weight"]

    # HQ MLP: source layers are 0-indexed, Refiners `Linear_N` layers are 1-indexed
    for i in range(3):
        for suffix in ("weight", "bias"):
            state_dict[f"Chain.HQSAMMaskPrediction.HQTokenMLP.MultiLinear.Linear_{i + 1}.{suffix}"] = (
                source_state_dict[f"mask_decoder.hf_mlp.layers.{i}.{suffix}"]
            )

    save_to_safetensors(path=args.output_path, tensors=state_dict)


if __name__ == "__main__":
    main()
|
|
@ -185,17 +185,16 @@ def convert_mask_decoder(mask_decoder: nn.Module) -> dict[str, Tensor]:
|
||||||
|
|
||||||
mapping = converter.map_state_dicts(source_args=inputs, target_args={})
|
mapping = converter.map_state_dicts(source_args=inputs, target_args={})
|
||||||
assert mapping is not None
|
assert mapping is not None
|
||||||
mapping["IOUMaskEncoder"] = "iou_token"
|
mapping["MaskDecoderTokens.Parameter"] = "iou_token"
|
||||||
|
|
||||||
state_dict = converter._convert_state_dict( # type: ignore
|
state_dict = converter._convert_state_dict( # type: ignore
|
||||||
source_state_dict=mask_decoder.state_dict(),
|
source_state_dict=mask_decoder.state_dict(),
|
||||||
target_state_dict=refiners_mask_decoder.state_dict(),
|
target_state_dict=refiners_mask_decoder.state_dict(),
|
||||||
state_dict_mapping=mapping,
|
state_dict_mapping=mapping,
|
||||||
)
|
)
|
||||||
state_dict["IOUMaskEncoder.weight"] = torch.cat(
|
state_dict["MaskDecoderTokens.Parameter.weight"] = torch.cat(
|
||||||
tensors=[mask_decoder.iou_token.weight, mask_decoder.mask_tokens.weight], dim=0
|
tensors=[mask_decoder.iou_token.weight, mask_decoder.mask_tokens.weight], dim=0
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
refiners_mask_decoder.load_state_dict(state_dict=state_dict)
|
refiners_mask_decoder.load_state_dict(state_dict=state_dict)
|
||||||
|
|
||||||
refiners_mask_decoder.set_image_embedding(image_embedding)
|
refiners_mask_decoder.set_image_embedding(image_embedding)
|
||||||
|
@ -254,10 +253,10 @@ def main() -> None:
|
||||||
mask_encoder_state_dict = convert_mask_encoder(prompt_encoder=sam_h.prompt_encoder)
|
mask_encoder_state_dict = convert_mask_encoder(prompt_encoder=sam_h.prompt_encoder)
|
||||||
|
|
||||||
output_state_dict = {
|
output_state_dict = {
|
||||||
**{".".join(("image_encoder", key)): value for key, value in vit_state_dict.items()},
|
**{f"SAMViTH.{key}": value for key, value in vit_state_dict.items()},
|
||||||
**{".".join(("mask_decoder", key)): value for key, value in mask_decoder_state_dict.items()},
|
**{f"MaskDecoder.{key}": value for key, value in mask_decoder_state_dict.items()},
|
||||||
**{".".join(("point_encoder", key)): value for key, value in point_encoder_state_dict.items()},
|
**{f"PointEncoder.{key}": value for key, value in point_encoder_state_dict.items()},
|
||||||
**{".".join(("mask_encoder", key)): value for key, value in mask_encoder_state_dict.items()},
|
**{f"MaskEncoder.{key}": value for key, value in mask_encoder_state_dict.items()},
|
||||||
}
|
}
|
||||||
if args.half:
|
if args.half:
|
||||||
output_state_dict = {key: value.half() for key, value in output_state_dict.items()}
|
output_state_dict = {key: value.half() for key, value in output_state_dict.items()}
|
||||||
|
|
|
@ -366,6 +366,13 @@ def download_sam():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def download_hq_sam():
    """Download the HQ-SAM ViT-H checkpoint into the test weights folder."""
    checkpoint_url = "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth"
    destination = os.path.join(test_weights_dir)
    download_file(checkpoint_url, destination, expected_hash="66da2472")
|
||||||
|
|
||||||
|
|
||||||
def download_dinov2():
|
def download_dinov2():
|
||||||
# For conversion
|
# For conversion
|
||||||
weights_folder = os.path.join(test_weights_dir)
|
weights_folder = os.path.join(test_weights_dir)
|
||||||
|
@ -661,7 +668,16 @@ def convert_sam():
|
||||||
"convert_segment_anything.py",
|
"convert_segment_anything.py",
|
||||||
"tests/weights/sam_vit_h_4b8939.pth",
|
"tests/weights/sam_vit_h_4b8939.pth",
|
||||||
"tests/weights/segment-anything-h.safetensors",
|
"tests/weights/segment-anything-h.safetensors",
|
||||||
expected_hash="b62ad5ed",
|
expected_hash="5ffb976f",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_hq_sam():
|
||||||
|
run_conversion_script(
|
||||||
|
"convert_hq_segment_anything.py",
|
||||||
|
"tests/weights/sam_hq_vit_h.pth",
|
||||||
|
"tests/weights/refiners-sam-hq-vit-h.safetensors",
|
||||||
|
expected_hash="b2f5e79f",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -769,6 +785,7 @@ def download_all():
|
||||||
download_ip_adapter()
|
download_ip_adapter()
|
||||||
download_t2i_adapter()
|
download_t2i_adapter()
|
||||||
download_sam()
|
download_sam()
|
||||||
|
download_hq_sam()
|
||||||
download_dinov2()
|
download_dinov2()
|
||||||
download_control_lora_fooocus()
|
download_control_lora_fooocus()
|
||||||
download_lcm_base()
|
download_lcm_base()
|
||||||
|
@ -789,6 +806,7 @@ def convert_all():
|
||||||
convert_ip_adapter()
|
convert_ip_adapter()
|
||||||
convert_t2i_adapter()
|
convert_t2i_adapter()
|
||||||
convert_sam()
|
convert_sam()
|
||||||
|
convert_hq_sam()
|
||||||
convert_dinov2()
|
convert_dinov2()
|
||||||
convert_control_lora_fooocus()
|
convert_control_lora_fooocus()
|
||||||
convert_lcm_base()
|
convert_lcm_base()
|
||||||
|
|
378
src/refiners/foundationals/segment_anything/hq_sam.py
Normal file
378
src/refiners/foundationals/segment_anything/hq_sam.py
Normal file
|
@ -0,0 +1,378 @@
|
||||||
|
import torch
|
||||||
|
from torch import device as Device, dtype as DType
|
||||||
|
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
from refiners.fluxion.adapters import Adapter
|
||||||
|
from refiners.fluxion.context import Contexts
|
||||||
|
from refiners.foundationals.segment_anything.image_encoder import SAMViT, TransformerLayer
|
||||||
|
from refiners.foundationals.segment_anything.mask_decoder import (
|
||||||
|
MaskDecoderTokens,
|
||||||
|
MaskPrediction,
|
||||||
|
Predictions,
|
||||||
|
)
|
||||||
|
from refiners.foundationals.segment_anything.model import SegmentAnything
|
||||||
|
|
||||||
|
|
||||||
|
class CompressViTFeat(fl.Chain):
    """Compress and upscale the early ViT embedding.

    Reads `early_vit_embedding` from the `hq_sam` context, permutes it with
    `(0, 3, 1, 2)` (presumably channels-last to channels-first; confirm against
    the image encoder), then applies two stride-2 transposed convolutions,
    separated by a LayerNorm2d and a GeLU, going from `vit_dim` channels to
    `transformer_dim // 8` channels.
    """

    def __init__(
        self,
        transformer_dim: int = 256,
        vit_dim: int = 1024,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        # Layers are built in chain order so that weights are initialized in
        # the same sequence as they appear in the chain.
        early_embedding = fl.UseContext(context="hq_sam", key="early_vit_embedding")
        permute = fl.Permute(0, 3, 1, 2)
        first_upscale = fl.ConvTranspose2d(
            in_channels=vit_dim,
            out_channels=transformer_dim,
            kernel_size=2,
            stride=2,
            device=device,
            dtype=dtype,
        )
        norm = fl.LayerNorm2d(channels=transformer_dim, device=device, dtype=dtype)
        activation = fl.GeLU()
        second_upscale = fl.ConvTranspose2d(
            in_channels=transformer_dim,
            out_channels=transformer_dim // 8,
            kernel_size=2,
            stride=2,
            device=device,
            dtype=dtype,
        )
        super().__init__(early_embedding, permute, first_upscale, norm, activation, second_upscale)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingEncoder(fl.Chain):
    """Upscale the image embedding stored in the `mask_decoder` context.

    Reads `image_embedding` from the context and applies two stride-2
    transposed convolutions, separated by a LayerNorm2d and a GeLU, going from
    `transformer_dim` to `transformer_dim // 4` and then `transformer_dim // 8`
    channels.
    """

    def __init__(
        self,
        transformer_dim: int = 256,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        quarter_dim = transformer_dim // 4
        eighth_dim = transformer_dim // 8

        # Layers are built in chain order so that weights are initialized in
        # the same sequence as they appear in the chain.
        image_embedding = fl.UseContext(context="mask_decoder", key="image_embedding")
        first_upscale = fl.ConvTranspose2d(
            in_channels=transformer_dim,
            out_channels=quarter_dim,
            kernel_size=2,
            stride=2,
            device=device,
            dtype=dtype,
        )
        norm = fl.LayerNorm2d(channels=quarter_dim, device=device, dtype=dtype)
        activation = fl.GeLU()
        second_upscale = fl.ConvTranspose2d(
            in_channels=quarter_dim,
            out_channels=eighth_dim,
            kernel_size=2,
            stride=2,
            device=device,
            dtype=dtype,
        )
        super().__init__(image_embedding, first_upscale, norm, activation, second_upscale)
|
||||||
|
|
||||||
|
|
||||||
|
class HQFeatures(fl.Sum):
    """Sum of the `EmbeddingEncoder` and `CompressViTFeat` branches."""

    def __init__(
        self,
        vit_dim: int = 1024,
        transformer_dim: int = 256,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        # EmbeddingEncoder is built first, matching its position in the sum.
        embedding_encoder = EmbeddingEncoder(transformer_dim, device, dtype)
        compress_vit_feat = CompressViTFeat(transformer_dim, vit_dim, device, dtype)
        super().__init__(embedding_encoder, compress_vit_feat)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingMaskfeature(fl.Chain):
    """Refine the upscaled dense embedding stored in the `mask_decoder` context.

    Reads `upscaled_dense_embedding` from the context, reshapes it, and applies
    two `fl.Conv2d` layers separated by a LayerNorm2d and a GeLU, going from
    `transformer_dim // 8` to `transformer_dim // 4` channels and back.
    """

    def __init__(
        self,
        transformer_dim: int = 256,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        quarter_dim = transformer_dim // 4
        eighth_dim = transformer_dim // 8

        upscaled_embedding = fl.UseContext(context="mask_decoder", key="upscaled_dense_embedding")
        # NOTE(review): `transformer_dim` is reused for the two trailing sizes
        # of the reshape; confirm this is intended rather than a separate
        # spatial size that merely happens to equal 256.
        reshape = fl.Reshape(-1, transformer_dim, transformer_dim)
        first_conv = fl.Conv2d(eighth_dim, quarter_dim, 3, 1, 1, device=device, dtype=dtype)
        norm = fl.LayerNorm2d(channels=quarter_dim, device=device, dtype=dtype)
        activation = fl.GeLU()
        second_conv = fl.Conv2d(quarter_dim, eighth_dim, 3, 1, 1, device=device, dtype=dtype)
        super().__init__(upscaled_embedding, reshape, first_conv, norm, activation, second_conv)
|
||||||
|
|
||||||
|
|
||||||
|
class DenseEmbeddingUpscalingHQ(fl.Sum):
    """Sum of the `EmbeddingMaskfeature` and `HQFeatures` branches."""

    def __init__(
        self,
        vit_dim: int = 1024,
        transformer_dim: int = 256,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        # EmbeddingMaskfeature is built first, matching its position in the sum.
        mask_feature = EmbeddingMaskfeature(transformer_dim, device, dtype)
        hq_features = HQFeatures(vit_dim, transformer_dim, device, dtype)
        super().__init__(mask_feature, hq_features)
|
||||||
|
|
||||||
|
|
||||||
|
class HQTokenMLP(fl.Chain):
    """Select the HQ token and project it through a `fl.MultiLinear`.

    The token is taken by slicing dim 1 at index `target_num_mask_tokens`;
    the projection goes from `embedding_dim` to `embedding_dim // 8`.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_layers: int = 3,
        target_num_mask_tokens: int = 5,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        hq_token_index = target_num_mask_tokens
        select_hq_token = fl.Slicing(dim=1, start=hq_token_index, end=hq_token_index + 1)  # HQ token
        projection = fl.MultiLinear(
            input_dim=embedding_dim,
            output_dim=embedding_dim // 8,
            inner_dim=embedding_dim,
            num_layers=num_layers,
            device=device,
            dtype=dtype,
        )
        super().__init__(select_hq_token, projection)
|
||||||
|
|
||||||
|
|
||||||
|
class HQSAMMaskPrediction(fl.Matmul):
    """Matrix product of the HQ token projection and the flattened HQ dense embedding.

    The first operand is `HQTokenMLP`; the second is `DenseEmbeddingUpscalingHQ`
    flattened from dim 2 onwards.
    """

    def __init__(
        self,
        embedding_dim: int,
        vit_dim: int = 1024,
        target_num_mask_tokens: int = 5,
        num_layers: int = 3,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        hq_token_mlp = HQTokenMLP(
            embedding_dim,
            num_layers=num_layers,
            target_num_mask_tokens=target_num_mask_tokens,
            device=device,
            dtype=dtype,
        )
        # NOTE(review): `transformer_dim` is hard-coded to 256 instead of being
        # derived from `embedding_dim`; confirm whether other sizes must be supported.
        dense_upscaling = DenseEmbeddingUpscalingHQ(vit_dim=vit_dim, transformer_dim=256, device=device, dtype=dtype)
        flattened_dense_embedding = fl.Chain(dense_upscaling, fl.Flatten(start_dim=2))
        super().__init__(hq_token_mlp, flattened_dense_embedding)
|
||||||
|
|
||||||
|
|
||||||
|
class MaskPredictionAdapter(fl.Concatenate, Adapter[MaskPrediction]):
    """Concatenate the HQ mask prediction to the output of the adapted `MaskPrediction`.

    Both branches are concatenated along dim 1, the target first.
    """

    def __init__(
        self,
        target: MaskPrediction,
        vit_dim: int = 1024,
        target_num_mask_tokens: int = 5,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        with self.setup_adapter(target):
            hq_mask_prediction = HQSAMMaskPrediction(
                embedding_dim=target.embedding_dim,
                vit_dim=vit_dim,
                target_num_mask_tokens=target_num_mask_tokens,
                num_layers=3,
                device=device,
                dtype=dtype,
            )
            hq_branch = fl.Chain(
                hq_mask_prediction,
                fl.Reshape(-1, target.embedding_dim, target.embedding_dim),
            )
            super().__init__(target, hq_branch, dim=1)

    @property
    def hq_sam_mask_prediction(self) -> HQSAMMaskPrediction:
        """The `HQSAMMaskPrediction` layer found inside this adapter."""
        return self.ensure_find(HQSAMMaskPrediction)
|
||||||
|
|
||||||
|
|
||||||
|
class MaskDecoderTokensExtender(fl.Concatenate, Adapter[MaskDecoderTokens]):
    """
    Extend the MaskDecoderTokens with an extra weight holding the new HQ token.

    The HQ token is concatenated after the target's tokens along dim 1.
    """

    def __init__(
        self,
        target: MaskDecoderTokens,
    ) -> None:
        hq_token = fl.Parameter(1, target.embedding_dim, device=target.device, dtype=target.dtype)
        # Stored inside a list (presumably to keep it from being registered as
        # a direct attribute submodule; confirm against the Adapter conventions).
        self._hq_token = [hq_token]

        with self.setup_adapter(target):
            hq_token_branch = fl.Chain(
                fl.UseContext(context="mask_decoder", key="image_embedding"),  # use Context to infer batch size
                hq_token,
            )
            super().__init__(target, hq_token_branch, dim=1)

    @property
    def regular_tokens(self) -> fl.Parameter:
        """The `fl.Parameter` found in the adapted `MaskDecoderTokens`."""
        return self.target.ensure_find(fl.Parameter)

    @property
    def hq_token(self) -> fl.Parameter:
        """The `fl.Parameter` holding the HQ token."""
        return self._hq_token[0]
|
||||||
|
|
||||||
|
|
||||||
|
class SAMViTAdapter(fl.Chain, Adapter[SAMViT]):
    """
    Store the early ViT embedding of the image encoder in a context.

    On injection, a `SetContext("hq_sam", "early_vit_embedding")` is appended to
    the first transformer layer whose `window_size` is None; it is removed
    again on ejection.
    """

    def __init__(self, target: SAMViT) -> None:
        with self.setup_adapter(target):
            super().__init__(target)

        first_global_layer = self._find_target_transformer_layer()
        assert first_global_layer is not None

        # Kept inside lists and exposed through the properties below.
        self._transformer_layer = [first_global_layer]
        self._set_early_vit_embedding_context = [fl.SetContext("hq_sam", "early_vit_embedding")]

    @property
    def target_transformer_layer(self) -> TransformerLayer:
        """The transformer layer that receives the `SetContext`."""
        return self._transformer_layer[0]

    @property
    def set_early_vit_embedding_context(self) -> fl.SetContext:
        """The `SetContext` writing `early_vit_embedding` in the `hq_sam` context."""
        return self._set_early_vit_embedding_context[0]

    def _find_target_transformer_layer(self) -> TransformerLayer | None:
        """Return the first `TransformerLayer` with no window size, or None if there is none."""
        candidates = (layer for layer in self.target.layers(TransformerLayer) if layer.window_size is None)
        return next(candidates, None)

    def inject(self: "SAMViTAdapter", parent: fl.Chain | None = None) -> "SAMViTAdapter":
        """Append the `SetContext` to the target transformer layer, then inject the adapter."""
        self.target_transformer_layer.append(self.set_early_vit_embedding_context)
        return super().inject(parent)

    def eject(self) -> None:
        """Remove the `SetContext` from the target transformer layer, then eject the adapter."""
        self.target_transformer_layer.remove(self.set_early_vit_embedding_context)
        super().eject()
|
||||||
|
|
||||||
|
|
||||||
|
class PredictionsPostProc(fl.Module):
    """Select the final mask from the predicted masks.

    The last mask along dim 1 is the HQ mask. It is returned either alone or
    added to the single base SAM mask, depending on `hq_mask_only`.
    """

    def __init__(self, hq_mask_only: bool = False) -> None:
        super().__init__()
        self.hq_mask_only = hq_mask_only

    def forward(
        self, masks_predictions: torch.Tensor, iou_predictions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the post-processed mask along with the unchanged IoU predictions."""
        hq_mask = masks_predictions[:, -1:, ...]

        # The official implementation of HQ-SAM has two outputs modes:
        # 1. HQ mask only
        # 2. HQ mask + base SAM mask, using HQ as a correction to the base SAM mask
        # Details can be found in the paper: https://arxiv.org/abs/2306.01567 (section 3.3)
        # Heuristics are provided by the authors here: https://github.com/SysCV/sam-hq/blob/3224888/demo/demo_hqsam_pip_example.py#L73-L75
        if not self.hq_mask_only:
            base_masks = masks_predictions[:, :-1, ...]
            # Exactly one base mask is expected besides the HQ one.
            assert base_masks.shape[1] == 1
            corrected_mask = hq_mask + base_masks
            return (corrected_mask, iou_predictions)

        return (hq_mask, iou_predictions)
|
||||||
|
|
||||||
|
|
||||||
|
class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]):
    """Adapter adding the HQ features to SAM.

    It drives four sub-components: a `MaskDecoderTokensExtender`, a
    `MaskPredictionAdapter`, a `SAMViTAdapter` and a `PredictionsPostProc`
    inserted after the mask decoder's `Predictions`.

    See [[arXiv:2306.01567] Segment Anything in High Quality](https://arxiv.org/abs/2306.01567) for details.
    """

    def init_context(self) -> Contexts:
        """Declare the `hq_sam` context with an empty `early_vit_embedding` slot."""
        return {"hq_sam": {"early_vit_embedding": None}}

    def __init__(
        self,
        target: SegmentAnything,
        hq_mask_only: bool = False,
        weights: dict[str, torch.Tensor] | None = None,
    ) -> None:
        """Build the sub-adapters around `target` and optionally load their weights.

        Args:
            target: The SegmentAnything model to adapt.
            hq_mask_only: Forwarded to `PredictionsPostProc`.
            weights: Optional state dict; entries prefixed with
                `MaskDecoderTokensExtender.hq_token.` and
                `Chain.HQSAMMaskPrediction.` are loaded, with the prefix stripped.

        Raises:
            NotImplementedError: if the target's mask decoder has `multimask_output` set.
        """
        self.vit_embedding_dim = target.image_encoder.embedding_dim
        self.target_num_mask_tokens = target.mask_decoder.num_multimask_outputs + 2

        with self.setup_adapter(target):
            super().__init__(target)

        if target.mask_decoder.multimask_output:
            raise NotImplementedError("Multi-mask mode is not supported in HQSAMAdapter.")

        # Sub-components are created in a fixed order and kept inside lists;
        # they are exposed through the properties below.
        mask_prediction = target.mask_decoder.ensure_find(MaskPrediction)
        mask_prediction_adapter = MaskPredictionAdapter(
            mask_prediction, self.vit_embedding_dim, self.target_num_mask_tokens, target.device, target.dtype
        )
        self._mask_prediction_adapter = [mask_prediction_adapter]
        self._image_encoder_adapter = [SAMViTAdapter(target.image_encoder)]
        self._predictions_post_proc = [PredictionsPostProc(hq_mask_only)]

        mask_decoder_tokens = target.mask_decoder.ensure_find(MaskDecoderTokens)
        self._mask_decoder_tokens_extender = [MaskDecoderTokensExtender(mask_decoder_tokens)]

        if weights is not None:

            def select(prefix: str) -> dict[str, torch.Tensor]:
                # Keep only the entries under `prefix`, with the prefix removed.
                return {key.removeprefix(prefix): value for key, value in weights.items() if key.startswith(prefix)}

            hq_token_state_dict = select("MaskDecoderTokensExtender.hq_token.")
            self.mask_decoder_tokens_extender.hq_token.load_state_dict(hq_token_state_dict)

            mask_pred_state_dict = select("Chain.HQSAMMaskPrediction.")
            self.mask_prediction_adapter.hq_sam_mask_prediction.load_state_dict(mask_pred_state_dict)

        self.to(device=target.device, dtype=target.dtype)

    @property
    def mask_decoder_tokens_extender(self) -> MaskDecoderTokensExtender:
        """The adapter adding the HQ token to the mask decoder tokens."""
        return self._mask_decoder_tokens_extender[0]

    @property
    def mask_prediction_adapter(self) -> MaskPredictionAdapter:
        """The adapter adding the HQ mask prediction."""
        return self._mask_prediction_adapter[0]

    @property
    def image_encoder_adapter(self) -> SAMViTAdapter:
        """The adapter wrapping the image encoder."""
        return self._image_encoder_adapter[0]

    @property
    def predictions_post_proc(self) -> PredictionsPostProc:
        """The post-processing module inserted after `Predictions`."""
        return self._predictions_post_proc[0]

    @property
    def hq_mask_only(self) -> bool:
        """Whether the post-processing returns the HQ mask alone."""
        return self.predictions_post_proc.hq_mask_only

    @hq_mask_only.setter
    def hq_mask_only(self, value: bool) -> None:
        self.predictions_post_proc.hq_mask_only = value

    def inject(self: "HQSAMAdapter", parent: fl.Chain | None = None) -> "HQSAMAdapter":
        """Inject every sub-adapter, insert the post-processing, then inject this adapter."""
        self.mask_decoder_tokens_extender.inject()
        self.mask_prediction_adapter.inject()
        self.image_encoder_adapter.inject()
        self.target.mask_decoder.insert_after_type(Predictions, self.predictions_post_proc)
        return super().inject(parent)

    def eject(self) -> None:
        """Eject every sub-adapter, remove the post-processing, then eject this adapter."""
        self.mask_decoder_tokens_extender.eject()
        self.mask_prediction_adapter.eject()
        self.image_encoder_adapter.eject()
        self.target.mask_decoder.remove(self.predictions_post_proc)
        super().eject()
|
|
@ -1,5 +1,5 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, device as Device, dtype as DType, nn
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
from refiners.fluxion.context import Contexts
|
from refiners.fluxion.context import Contexts
|
||||||
|
@ -10,7 +10,7 @@ from refiners.foundationals.segment_anything.transformer import (
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsAggregator(fl.ContextModule):
|
class EmbeddingsAggregator(fl.ContextModule):
|
||||||
def forward(self, iou_mask_tokens: Tensor) -> Tensor:
|
def forward(self, tokens: Tensor) -> Tensor:
|
||||||
mask_decoder = self.ensure_parent
|
mask_decoder = self.ensure_parent
|
||||||
mask_decoder_context = mask_decoder.use_context(context_name="mask_decoder")
|
mask_decoder_context = mask_decoder.use_context(context_name="mask_decoder")
|
||||||
image_embedding = mask_decoder_context["image_embedding"]
|
image_embedding = mask_decoder_context["image_embedding"]
|
||||||
|
@ -18,7 +18,7 @@ class EmbeddingsAggregator(fl.ContextModule):
|
||||||
mask_embedding = mask_decoder_context["mask_embedding"]
|
mask_embedding = mask_decoder_context["mask_embedding"]
|
||||||
dense_positional_embedding = mask_decoder_context["dense_positional_embedding"]
|
dense_positional_embedding = mask_decoder_context["dense_positional_embedding"]
|
||||||
|
|
||||||
sparse_embedding = torch.cat(tensors=(iou_mask_tokens, point_embedding), dim=1)
|
sparse_embedding = torch.cat(tensors=(tokens, point_embedding), dim=1)
|
||||||
dense_embedding = (image_embedding + mask_embedding).flatten(start_dim=2).transpose(1, 2)
|
dense_embedding = (image_embedding + mask_embedding).flatten(start_dim=2).transpose(1, 2)
|
||||||
if dense_positional_embedding.shape != dense_embedding.shape:
|
if dense_positional_embedding.shape != dense_embedding.shape:
|
||||||
dense_positional_embedding = dense_positional_embedding.flatten(start_dim=2).transpose(1, 2)
|
dense_positional_embedding = dense_positional_embedding.flatten(start_dim=2).transpose(1, 2)
|
||||||
|
@ -108,10 +108,11 @@ class DenseEmbeddingUpscaling(fl.Chain):
|
||||||
),
|
),
|
||||||
fl.GeLU(),
|
fl.GeLU(),
|
||||||
fl.Flatten(start_dim=2),
|
fl.Flatten(start_dim=2),
|
||||||
|
fl.SetContext(context="mask_decoder", key="upscaled_dense_embedding"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class IOUMaskEncoder(fl.WeightedModule):
|
class MaskDecoderTokens(fl.Chain):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
embedding_dim: int = 256,
|
embedding_dim: int = 256,
|
||||||
|
@ -119,14 +120,13 @@ class IOUMaskEncoder(fl.WeightedModule):
|
||||||
device: Device | str | None = None,
|
device: Device | str | None = None,
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
|
||||||
self.embedding_dim = embedding_dim
|
self.embedding_dim = embedding_dim
|
||||||
self.num_mask_tokens = num_mask_tokens
|
self.num_mask_tokens = num_mask_tokens
|
||||||
# aka prompt tokens + output token (for IoU scores prediction)
|
# aka output tokens (single-mask output + multi-mask output) + IoU token
|
||||||
self.weight = nn.Parameter(data=torch.randn(num_mask_tokens + 1, embedding_dim, device=device, dtype=dtype))
|
super().__init__(
|
||||||
|
fl.UseContext(context="mask_decoder", key="image_embedding"), # use Context to infer batch size
|
||||||
def forward(self) -> Tensor:
|
fl.Parameter(num_mask_tokens + 1, embedding_dim, device=device, dtype=dtype),
|
||||||
return self.weight.unsqueeze(dim=0)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MaskPrediction(fl.Chain):
|
class MaskPrediction(fl.Chain):
|
||||||
|
@ -178,7 +178,6 @@ class IOUPrediction(fl.Chain):
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Slicing(dim=1, start=0, end=1),
|
fl.Slicing(dim=1, start=0, end=1),
|
||||||
fl.Squeeze(dim=0),
|
|
||||||
fl.MultiLinear(
|
fl.MultiLinear(
|
||||||
input_dim=embedding_dim,
|
input_dim=embedding_dim,
|
||||||
output_dim=num_mask_tokens,
|
output_dim=num_mask_tokens,
|
||||||
|
@ -188,6 +187,39 @@ class IOUPrediction(fl.Chain):
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
),
|
),
|
||||||
fl.Slicing(dim=-1, start=1) if multimask_output else fl.Slicing(dim=-1, start=0, end=1),
|
fl.Slicing(dim=-1, start=1) if multimask_output else fl.Slicing(dim=-1, start=0, end=1),
|
||||||
|
fl.Squeeze(dim=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Predictions(fl.Parallel):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim: int,
|
||||||
|
num_mask_tokens: int,
|
||||||
|
multimask_output: bool,
|
||||||
|
num_layers: int = 3,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
self.num_mask_tokens = num_mask_tokens
|
||||||
|
self.num_layers = num_layers
|
||||||
|
super().__init__(
|
||||||
|
MaskPrediction(
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
num_mask_tokens=num_mask_tokens,
|
||||||
|
multimask_output=multimask_output,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
IOUPrediction(
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
num_layers=num_layers,
|
||||||
|
num_mask_tokens=num_mask_tokens,
|
||||||
|
multimask_output=multimask_output,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -213,7 +245,7 @@ class MaskDecoder(fl.Chain):
|
||||||
num_mask_tokens = self.num_multimask_outputs + 1
|
num_mask_tokens = self.num_multimask_outputs + 1
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
IOUMaskEncoder(embedding_dim=embedding_dim, num_mask_tokens=num_mask_tokens, device=device, dtype=dtype),
|
MaskDecoderTokens(embedding_dim=embedding_dim, num_mask_tokens=num_mask_tokens, device=device, dtype=dtype),
|
||||||
EmbeddingsAggregator(),
|
EmbeddingsAggregator(),
|
||||||
Transformer(
|
Transformer(
|
||||||
*(
|
*(
|
||||||
|
@ -230,22 +262,12 @@ class MaskDecoder(fl.Chain):
|
||||||
SparseCrossDenseAttention(embedding_dim=embedding_dim, device=device, dtype=dtype),
|
SparseCrossDenseAttention(embedding_dim=embedding_dim, device=device, dtype=dtype),
|
||||||
fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
|
fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
|
||||||
),
|
),
|
||||||
fl.Parallel(
|
Predictions(
|
||||||
MaskPrediction(
|
embedding_dim=embedding_dim,
|
||||||
embedding_dim=embedding_dim,
|
num_mask_tokens=num_mask_tokens,
|
||||||
num_mask_tokens=num_mask_tokens,
|
multimask_output=multimask_output,
|
||||||
multimask_output=multimask_output,
|
device=device,
|
||||||
device=device,
|
dtype=dtype,
|
||||||
dtype=dtype,
|
|
||||||
),
|
|
||||||
IOUPrediction(
|
|
||||||
embedding_dim=embedding_dim,
|
|
||||||
num_layers=3,
|
|
||||||
num_mask_tokens=num_mask_tokens,
|
|
||||||
multimask_output=multimask_output,
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ class ImageEmbedding:
|
||||||
original_image_size: tuple[int, int] # (height, width)
|
original_image_size: tuple[int, int] # (height, width)
|
||||||
|
|
||||||
|
|
||||||
class SegmentAnything(fl.Module):
|
class SegmentAnything(fl.Chain):
|
||||||
"""SegmentAnything model.
|
"""SegmentAnything model.
|
||||||
|
|
||||||
See [[arXiv:2304.02643] Segment Anything](https://arxiv.org/abs/2304.02643)
|
See [[arXiv:2304.02643] Segment Anything](https://arxiv.org/abs/2304.02643)
|
||||||
|
@ -47,16 +47,30 @@ class SegmentAnything(fl.Module):
|
||||||
point_encoder: The point encoder to use.
|
point_encoder: The point encoder to use.
|
||||||
mask_encoder: The mask encoder to use.
|
mask_encoder: The mask encoder to use.
|
||||||
mask_decoder: The mask decoder to use.
|
mask_decoder: The mask decoder to use.
|
||||||
device: The PyTorch device to use.
|
|
||||||
dtype: The PyTorch data type to use.
|
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__(image_encoder, point_encoder, mask_encoder, mask_decoder)
|
||||||
self.device: Device = device if isinstance(device, Device) else Device(device=device)
|
|
||||||
self.dtype = dtype
|
self.to(device=device, dtype=dtype)
|
||||||
self.image_encoder = image_encoder.to(device=self.device, dtype=self.dtype)
|
|
||||||
self.point_encoder = point_encoder.to(device=self.device, dtype=self.dtype)
|
@property
|
||||||
self.mask_encoder = mask_encoder.to(device=self.device, dtype=self.dtype)
|
def image_encoder(self) -> SAMViT:
|
||||||
self.mask_decoder = mask_decoder.to(device=self.device, dtype=self.dtype)
|
"""The image encoder."""
|
||||||
|
return self.ensure_find(SAMViT)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def point_encoder(self) -> PointEncoder:
|
||||||
|
"""The point encoder."""
|
||||||
|
return self.ensure_find(PointEncoder)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mask_encoder(self) -> MaskEncoder:
|
||||||
|
"""The mask encoder."""
|
||||||
|
return self.ensure_find(MaskEncoder)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mask_decoder(self) -> MaskDecoder:
|
||||||
|
"""The mask decoder."""
|
||||||
|
return self.ensure_find(MaskDecoder)
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
|
def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
|
||||||
|
@ -259,11 +273,11 @@ class SegmentAnythingH(SegmentAnything):
|
||||||
else:
|
else:
|
||||||
mask_decoder = MaskDecoder(multimask_output) if multimask_output is not None else MaskDecoder()
|
mask_decoder = MaskDecoder(multimask_output) if multimask_output is not None else MaskDecoder()
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(image_encoder, point_encoder, mask_encoder, mask_decoder)
|
||||||
image_encoder=image_encoder,
|
|
||||||
point_encoder=point_encoder,
|
self.to(device=device, dtype=dtype)
|
||||||
mask_encoder=mask_encoder,
|
|
||||||
mask_decoder=mask_decoder,
|
@property
|
||||||
device=device,
|
def image_encoder(self) -> SAMViTH:
|
||||||
dtype=dtype,
|
"""The image encoder."""
|
||||||
)
|
return self.ensure_find(SAMViTH)
|
||||||
|
|
|
@ -98,7 +98,7 @@ class PointEncoder(fl.Chain):
|
||||||
) -> Float[Tensor, "num_positional_features height width"]:
|
) -> Float[Tensor, "num_positional_features height width"]:
|
||||||
coordinate_encoder = self.ensure_find(layer_type=CoordinateEncoder)
|
coordinate_encoder = self.ensure_find(layer_type=CoordinateEncoder)
|
||||||
height, width = image_embedding_size
|
height, width = image_embedding_size
|
||||||
grid = torch.ones((height, width), device=self.device, dtype=torch.float32)
|
grid = torch.ones((height, width), device=self.device, dtype=self.dtype)
|
||||||
y_embedding = grid.cumsum(dim=0) - 0.5
|
y_embedding = grid.cumsum(dim=0) - 0.5
|
||||||
x_embedding = grid.cumsum(dim=1) - 0.5
|
x_embedding = grid.cumsum(dim=1) - 0.5
|
||||||
y_embedding = y_embedding / height
|
y_embedding = y_embedding / height
|
||||||
|
|
18
tests/foundationals/segment_anything/conftest.py
Normal file
18
tests/foundationals/segment_anything/conftest.py
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
from pathlib import Path
|
||||||
|
from warnings import warn
|
||||||
|
|
||||||
|
from pytest import fixture, skip
|
||||||
|
|
||||||
|
|
||||||
|
@fixture(scope="package")
def ref_path(test_sam_path: Path) -> Path:
    """Return the `test_sam_ref` directory located under `test_sam_path`."""
    reference_dir = test_sam_path / "test_sam_ref"
    return reference_dir
|
||||||
|
|
||||||
|
|
||||||
|
@fixture(scope="package")
def sam_h_weights(test_weights_path: Path) -> Path:
    """Return the path of `segment-anything-h.safetensors`, skipping when the file is missing."""
    weights_file = test_weights_path / "segment-anything-h.safetensors"
    if weights_file.is_file():
        return weights_file

    # Missing weights are not treated as a failure: warn, then skip (skip raises).
    warn(f"could not find weights at {weights_file}, skipping")
    skip(allow_module_level=True)
|
296
tests/foundationals/segment_anything/test_hq_sam.py
Normal file
296
tests/foundationals/segment_anything/test_hq_sam.py
Normal file
|
@ -0,0 +1,296 @@
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import cast
|
||||||
|
from warnings import warn
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from segment_anything_hq import ( # type: ignore
|
||||||
|
SamPredictor as SamPredictorHQ,
|
||||||
|
sam_model_registry as sam_model_registry_hq, # type: ignore
|
||||||
|
)
|
||||||
|
from segment_anything_hq.modeling.sam import Sam # type: ignore
|
||||||
|
from tests.foundationals.segment_anything.utils import FacebookSAM, FacebookSAMPredictorHQ, SAMPrompt
|
||||||
|
from torch import optim
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, no_grad
|
||||||
|
from refiners.foundationals.segment_anything.hq_sam import (
|
||||||
|
CompressViTFeat,
|
||||||
|
EmbeddingEncoder,
|
||||||
|
HQSAMAdapter,
|
||||||
|
HQTokenMLP,
|
||||||
|
MaskDecoderTokensExtender,
|
||||||
|
PredictionsPostProc,
|
||||||
|
)
|
||||||
|
from refiners.foundationals.segment_anything.model import SegmentAnythingH
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def one_prompt() -> SAMPrompt:
    """Return a prompt holding a single box, described by two points."""
    single_box = [(4, 13), (1007, 1023)]
    return SAMPrompt(box_points=[single_box])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def tennis(ref_path: Path) -> Image.Image:
    """Open `tennis.png` from the reference directory and convert it to RGB."""
    image = Image.open(ref_path / "tennis.png")
    return image.convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def hq_adapter_weights(test_weights_path: Path) -> Path:
    """Locate the HQ adapter weights stored in Refiners format, skipping when they are missing."""
    weights_file = test_weights_path / "refiners-sam-hq-vit-h.safetensors"
    if weights_file.is_file():
        return weights_file

    # Missing weights are not treated as a failure: warn, then skip (skip raises).
    warn(f"Test weights not found at {weights_file}, skipping")
    pytest.skip(allow_module_level=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
    """Build a single-output `SegmentAnythingH` on the test device and load its weights."""
    # multimask_output=False: HQSAMAdapter is designed to be used with single-output only.
    model = SegmentAnythingH(multimask_output=False, device=test_device)
    model.load_from_safetensors(tensors_path=sam_h_weights)
    return model
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def reference_hq_adapter_weights(test_weights_path: Path) -> Path:
    """Locate the HQ adapter weights in their default (`.pth`) format, skipping when they are missing."""
    checkpoint_file = test_weights_path / "sam_hq_vit_h.pth"
    if checkpoint_file.is_file():
        return checkpoint_file

    # Missing weights are not treated as a failure: warn, then skip (skip raises).
    warn(f"Test weights not found at {checkpoint_file}, skipping")
    pytest.skip(allow_module_level=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def reference_sam_h(reference_hq_adapter_weights: Path, test_device: torch.device) -> FacebookSAM:
    """Build the `vit_h` model from the `segment_anything_hq` registry and move it to the test device."""
    build_vit_h = sam_model_registry_hq["vit_h"]
    reference_model = cast(FacebookSAM, build_vit_h(checkpoint=reference_hq_adapter_weights))
    return reference_model.to(device=test_device)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def reference_sam_h_predictor(reference_sam_h: FacebookSAM) -> FacebookSAMPredictorHQ:
    """Wrap the reference model in the `segment_anything_hq` predictor."""
    # The casts only reconcile the untyped third-party classes with the local typing stubs.
    return cast(FacebookSAMPredictorHQ, SamPredictorHQ(cast(Sam, reference_sam_h)))
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_eject() -> None:
    """Injecting the adapter changes the model's repr; ejecting restores it."""
    model = SegmentAnythingH(multimask_output=False)
    pristine_repr = repr(model)

    hq_adapter = HQSAMAdapter(model)
    # Building the adapter alone must leave the target untouched.
    assert repr(model) == pristine_repr

    hq_adapter.inject()
    assert repr(model) != pristine_repr

    hq_adapter.eject()
    assert repr(model) == pristine_repr
|
||||||
|
|
||||||
|
|
||||||
|
def test_multimask_forbidden() -> None:
    """Adapting a multimask `SegmentAnythingH` must raise `NotImplementedError`."""
    expected_error = pytest.raises(NotImplementedError, match="not supported")
    with expected_error:
        HQSAMAdapter(target=SegmentAnythingH(multimask_output=True))
|
||||||
|
|
||||||
|
|
||||||
|
def test_output_shape_hq_adapter(tennis: Image.Image, one_prompt: SAMPrompt) -> None:
    """Check the shapes returned by `predict` once the HQ adapter is injected (no weights loaded)."""
    model = SegmentAnythingH(multimask_output=False)
    HQSAMAdapter(model).inject()

    prompt_kwargs = one_prompt.__dict__
    high_res_masks, iou_predictions, low_res_masks = model.predict(tennis, **prompt_kwargs)

    assert high_res_masks.shape == (1, 1, 1024, 1024)
    assert iou_predictions.shape == (1, 1)
    assert low_res_masks.shape == (1, 1, 256, 256)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mask_decoder_tokens_extender() -> None:
    """After one SGD step, only the added (1, 256) token may change; the (5, 256) tokens stay frozen."""
    model = SegmentAnythingH(multimask_output=False)
    model.requires_grad_(False)

    # MaskDecoderTokens requires image_embedding context to be set
    model.mask_decoder.set_image_embedding(torch.randn(2, 256, 64, 64))

    HQSAMAdapter(model).inject()
    extender = model.ensure_find(MaskDecoderTokensExtender)

    tokens_before = extender()
    assert tokens_before.shape == torch.Size([2, 6, 256])

    # The pre-existing tokens were frozen above; only the token added by the adapter is trainable.
    for param in extender.parameters():
        param_shape = tuple(param.shape)
        if param_shape == (5, 256):
            assert not param.requires_grad
        elif param_shape == (1, 256):
            assert param.requires_grad
        else:
            raise ValueError

    sgd = optim.SGD(extender.parameters(), lr=10)
    sgd.zero_grad()

    target = torch.ones_like(tokens_before)
    loss = torch.nn.functional.mse_loss(tokens_before, target)
    loss.backward()  # type: ignore
    sgd.step()

    tokens_after = extender()

    assert torch.equal(tokens_before[:, :5, :], tokens_after[:, :5, :])
    assert not torch.equal(tokens_before[:, 5, :], tokens_after[:, 5, :])
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
def test_early_vit_embedding(
    sam_h: SegmentAnythingH,
    hq_adapter_weights: Path,
    reference_sam_h: FacebookSAM,
    tennis: Image.Image,
) -> None:
    """The early ViT embedding stored in the `hq_sam` context must equal the reference's first intermediate embedding."""
    adapter_weights = load_from_safetensors(hq_adapter_weights)
    HQSAMAdapter(sam_h, weights=adapter_weights).inject()

    resized = tennis.resize(size=(1024, 1024))
    image_tensor = image_to_tensor(image=resized)

    # Refiners: run the encoder for its side effect, then read the value back from the context.
    _ = sam_h.image_encoder(image_tensor.to(sam_h.device))
    hq_context = sam_h.use_context(context_name="hq_sam")
    refiners_embedding = hq_context["early_vit_embedding"]

    # Reference: the encoder returns the intermediate embeddings as its second output.
    _, intermediate_embeddings = reference_sam_h.image_encoder(image_tensor.to(reference_sam_h.device))
    reference_embedding = intermediate_embeddings[0]

    assert torch.equal(reference_embedding, refiners_embedding)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tokens(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
    """The adapter's token weights must equal the corresponding reference mask-decoder tokens."""
    adapter_weights = load_from_safetensors(hq_adapter_weights)
    HQSAMAdapter(sam_h, weights=adapter_weights).inject()

    extender = sam_h.mask_decoder.ensure_find(MaskDecoderTokensExtender)
    reference_decoder = reference_sam_h.mask_decoder

    # HQ token (1, 256), exposed as `hf_token` on the reference decoder
    assert torch.equal(reference_decoder.hf_token.weight, extender.hq_token.weight)

    # Regular tokens (5, 256): the reference IoU token concatenated with its mask tokens
    reference_regular_tokens = torch.cat([reference_decoder.iou_token.weight, reference_decoder.mask_tokens.weight])
    assert torch.equal(reference_regular_tokens, extender.regular_tokens.weight)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
def test_compress_vit_feat(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
    """`CompressViTFeat` must produce exactly the reference `compress_vit_feat` output."""
    adapter_weights = load_from_safetensors(hq_adapter_weights)
    HQSAMAdapter(sam_h, weights=adapter_weights).inject()

    features = torch.randn(1, 64, 64, 1280, device=sam_h.device, dtype=sam_h.dtype)

    # The Refiners layer is called without arguments; its input is provided via the `hq_sam` context.
    sam_h.set_context(context="hq_sam", value={"early_vit_embedding": features})
    compress_layer = sam_h.ensure_find(CompressViTFeat)
    refiners_output = compress_layer()

    # The reference module is fed the same tensor with its last axis moved to position 1.
    permuted_features = features.permute(0, 3, 1, 2)
    reference_output = reference_sam_h.mask_decoder.compress_vit_feat(permuted_features)

    assert torch.equal(refiners_output, reference_output)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
def test_embedding_encoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
    """`EmbeddingEncoder` must produce exactly the reference `embedding_encoder` output."""
    adapter_weights = load_from_safetensors(hq_adapter_weights)
    HQSAMAdapter(sam_h, weights=adapter_weights).inject()

    image_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype)

    # The Refiners layer is called without arguments; its input is provided via the `mask_decoder` context.
    sam_h.set_context(context="mask_decoder", value={"image_embedding": image_embedding})
    encoder_layer = sam_h.ensure_find(EmbeddingEncoder)
    refiners_output = encoder_layer()

    reference_output = reference_sam_h.mask_decoder.embedding_encoder(image_embedding)

    assert torch.equal(refiners_output, reference_output)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
def test_hq_token_mlp(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
    """`HQTokenMLP` must match the reference `hf_mlp` applied to the last token."""
    adapter_weights = load_from_safetensors(hq_adapter_weights)
    HQSAMAdapter(sam_h, weights=adapter_weights).inject()

    tokens = torch.randn(1, 6, 256, device=sam_h.device, dtype=sam_h.dtype)

    token_mlp = sam_h.ensure_find(HQTokenMLP)
    refiners_output = token_mlp(tokens)

    # The reference MLP only receives the last token; unsqueeze restores a leading axis for comparison.
    last_token = tokens[:, -1, :]
    reference_output = reference_sam_h.mask_decoder.hf_mlp(last_token).unsqueeze(0)

    assert torch.equal(refiners_output, reference_output)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("hq_mask_only", [True, False])
def test_predictor(
    sam_h: SegmentAnythingH,
    hq_adapter_weights: Path,
    hq_mask_only: bool,
    reference_sam_h_predictor: FacebookSAMPredictorHQ,
    tennis: Image.Image,
    one_prompt: SAMPrompt,
) -> None:
    """Compare Refiners predictions with the reference HQ predictor for both `hq_mask_only` settings."""

    def as_cpu_float(tensor: torch.Tensor) -> torch.Tensor:
        # Bring a Refiners output to float32 on CPU so it can be compared with the reference.
        return tensor.to(dtype=torch.float32).detach().cpu()

    adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()

    # Setting the flag on the adapter must be reflected on the injected post-processing layer.
    adapter.hq_mask_only = hq_mask_only
    assert sam_h.ensure_find(PredictionsPostProc).hq_mask_only == hq_mask_only

    # Refiners
    prompt_kwargs = one_prompt.__dict__
    high_res_masks, iou_predictions, low_res_masks = sam_h.predict(tennis, **prompt_kwargs)
    refiners_high_res_mask_hq = as_cpu_float(high_res_masks[0, 0, ...])
    refiners_low_res_mask_hq = as_cpu_float(low_res_masks[0, 0, ...])
    refiners_iou = as_cpu_float(iou_predictions[0, :])

    # Reference
    reference_sam_h_predictor.set_image(np.array(tennis))

    box_points = one_prompt.__dict__["box_points"]
    masks_np, iou_predictions_np, low_res_masks_np = reference_sam_h_predictor.predict(
        box=np.array(box_points).flatten(),
        multimask_output=False,
        hq_token_only=hq_mask_only,
    )

    reference_high_res_mask_hq = torch.from_numpy(masks_np[0, ...]).to(dtype=torch.float32)  # type: ignore
    reference_low_res_mask_hq = torch.from_numpy(low_res_masks_np[0, ...]).to(dtype=torch.float32)  # type: ignore
    reference_iou = torch.from_numpy(iou_predictions_np).to(dtype=torch.float32)  # type: ignore

    # NOTE: Diff on logits is relatively high, but on the same scale / even lower than base SAM logits diff (6e-3)
    # See https://github.com/finegrain-ai/refiners/blob/c6b5eb24a179d48e4542d94684a70c5ef3142ab1/tests/foundationals/segment_anything/test_sam.py#L426
    assert torch.allclose(reference_low_res_mask_hq, refiners_low_res_mask_hq, atol=4e-3)

    # The diff on the logits above leads to an absolute diff of 1 pixel on the high res masks
    high_res_abs_diff = torch.abs(reference_high_res_mask_hq - refiners_high_res_mask_hq)
    assert high_res_abs_diff.flatten().sum() <= 1

    assert torch.allclose(reference_iou, torch.max(refiners_iou), atol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) -> None:
    """With the HQ adapter injected, the mask decoder handles a batch and gives identical rows for identical inputs."""
    HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()

    batch_size = 5

    def replicated(*shape: int) -> torch.Tensor:
        # One random sample repeated along the batch axis, so every batch entry is identical.
        sample = torch.randn(1, *shape, device=sam_h.device, dtype=sam_h.dtype)
        return sample.repeat(batch_size, *(1 for _ in shape))

    # Keep this creation order: it fixes the order in which random numbers are drawn.
    image_embedding = replicated(256, 64, 64)
    mask_embedding = replicated(256, 64, 64)
    dense_positional_embedding = replicated(256, 64, 64)
    point_embedding = replicated(2, 256)
    early_vit_embedding = replicated(64, 64, 1280)

    decoder = sam_h.mask_decoder
    decoder.set_image_embedding(image_embedding)
    decoder.set_mask_embedding(mask_embedding)
    decoder.set_point_embedding(point_embedding)
    decoder.set_dense_positional_embedding(dense_positional_embedding)
    decoder.set_context(
        context="hq_sam", value={"early_vit_embedding": early_vit_embedding.to(sam_h.device, sam_h.dtype)}
    )

    mask_prediction, iou_prediction = sam_h.mask_decoder()

    assert mask_prediction.shape == (batch_size, 1, 256, 256)
    assert iou_prediction.shape == (batch_size, 1)
    assert torch.equal(mask_prediction[0], mask_prediction[1])
|
|
@ -56,15 +56,6 @@ def facebook_sam_h_weights(test_weights_path: Path) -> Path:
|
||||||
return sam_h_weights
|
return sam_h_weights
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def sam_h_weights(test_weights_path: Path) -> Path:
|
|
||||||
sam_h_weights = test_weights_path / "segment-anything-h.safetensors"
|
|
||||||
if not sam_h_weights.is_file():
|
|
||||||
warn(f"could not find weights at {sam_h_weights}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
return sam_h_weights
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def facebook_sam_h(facebook_sam_h_weights: Path, test_device: torch.device) -> FacebookSAM:
|
def facebook_sam_h(facebook_sam_h_weights: Path, test_device: torch.device) -> FacebookSAM:
|
||||||
from segment_anything import build_sam_vit_h # type: ignore
|
from segment_anything import build_sam_vit_h # type: ignore
|
||||||
|
@ -98,11 +89,6 @@ def sam_h_single_output(sam_h_weights: Path, test_device: torch.device) -> Segme
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def ref_path(test_sam_path: Path) -> Path:
|
|
||||||
return test_sam_path / "test_sam_ref"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def truck(ref_path: Path) -> Image.Image:
|
def truck(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(ref_path / "truck.jpg").convert("RGB")
|
return Image.open(ref_path / "truck.jpg").convert("RGB")
|
||||||
|
|
||||||
|
@ -283,14 +269,14 @@ def test_mask_decoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> N
|
||||||
|
|
||||||
mapping = converter.map_state_dicts(source_args=inputs, target_args={})
|
mapping = converter.map_state_dicts(source_args=inputs, target_args={})
|
||||||
assert mapping is not None
|
assert mapping is not None
|
||||||
mapping["IOUMaskEncoder"] = "iou_token"
|
mapping["MaskDecoderTokens.Parameter"] = "iou_token"
|
||||||
|
|
||||||
state_dict = converter._convert_state_dict( # type: ignore
|
state_dict = converter._convert_state_dict( # type: ignore
|
||||||
source_state_dict=facebook_mask_decoder.state_dict(),
|
source_state_dict=facebook_mask_decoder.state_dict(),
|
||||||
target_state_dict=refiners_mask_decoder.state_dict(),
|
target_state_dict=refiners_mask_decoder.state_dict(),
|
||||||
state_dict_mapping=mapping,
|
state_dict_mapping=mapping,
|
||||||
)
|
)
|
||||||
state_dict["IOUMaskEncoder.weight"] = torch.cat(
|
state_dict["MaskDecoderTokens.Parameter.weight"] = torch.cat(
|
||||||
[facebook_mask_decoder.iou_token.weight, facebook_mask_decoder.mask_tokens.weight], dim=0
|
[facebook_mask_decoder.iou_token.weight, facebook_mask_decoder.mask_tokens.weight], dim=0
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
refiners_mask_decoder.load_state_dict(state_dict=state_dict)
|
refiners_mask_decoder.load_state_dict(state_dict=state_dict)
|
||||||
|
@ -462,3 +448,26 @@ def test_mask_encoder(
|
||||||
|
|
||||||
assert facebook_mask_input.shape == mask_input.shape
|
assert facebook_mask_input.shape == mask_input.shape
|
||||||
assert torch.allclose(dense_embeddings, fb_dense_embeddings, atol=1e-4, rtol=1e-4)
|
assert torch.allclose(dense_embeddings, fb_dense_embeddings, atol=1e-4, rtol=1e-4)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
def test_batch_mask_decoder(sam_h: SegmentAnythingH) -> None:
    """The mask decoder handles a batch and gives identical rows for identical inputs."""
    batch_size = 5

    def replicated(*shape: int) -> torch.Tensor:
        # One random sample repeated along the batch axis, so every batch entry is identical.
        sample = torch.randn(1, *shape, device=sam_h.device, dtype=sam_h.dtype)
        return sample.repeat(batch_size, *(1 for _ in shape))

    # Keep this creation order: it fixes the order in which random numbers are drawn.
    image_embedding = replicated(256, 64, 64)
    mask_embedding = replicated(256, 64, 64)
    dense_positional_embedding = replicated(256, 64, 64)
    point_embedding = replicated(2, 256)

    decoder = sam_h.mask_decoder
    decoder.set_image_embedding(image_embedding)
    decoder.set_mask_embedding(mask_embedding)
    decoder.set_point_embedding(point_embedding)
    decoder.set_dense_positional_embedding(dense_positional_embedding)

    mask_prediction, iou_prediction = sam_h.mask_decoder()

    assert mask_prediction.shape == (batch_size, 3, 256, 256)
    assert iou_prediction.shape == (batch_size, 3)
    assert torch.equal(mask_prediction[0], mask_prediction[1])
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
# Note about this data
|
# Note about this data
|
||||||
|
|
||||||
`truck.jpg` is one of the [images](https://github.com/facebookresearch/segment-anything/tree/main/notebooks/images) used in the official [segment-anything notebooks](https://github.com/facebookresearch/segment-anything/tree/main/notebooks).
|
`truck.jpg` is one of the [images](https://github.com/facebookresearch/segment-anything/tree/main/notebooks/images) used in the official [segment-anything notebooks](https://github.com/facebookresearch/segment-anything/tree/main/notebooks).
|
||||||
|
|
||||||
|
`tennis.png` is one of the [images](https://github.com/SysCV/sam-hq/tree/main/demo/input_imgs) used in the official [sam-hq demos](https://github.com/SysCV/sam-hq/tree/main/demo).
|
||||||
|
|
BIN
tests/foundationals/segment_anything/test_sam_ref/tennis.png
Normal file
BIN
tests/foundationals/segment_anything/test_sam_ref/tennis.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.2 MiB |
|
@ -54,6 +54,23 @@ class FacebookSAMPredictor:
|
||||||
) -> tuple[NDArray, NDArray, NDArray]: ...
|
) -> tuple[NDArray, NDArray, NDArray]: ...
|
||||||
|
|
||||||
|
|
||||||
|
class FacebookSAMPredictorHQ:
    """Typing stub for the HQ-SAM predictor used as reference in the tests.

    Compared with `FacebookSAMPredictor`, `predict` accepts an extra
    `hq_token_only` flag.
    """

    # The wrapped reference model.
    model: FacebookSAM

    def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None:
        """Stub signature only; the real predictor provides the implementation."""
        ...

    def predict(
        self,
        point_coords: NDArray | None = None,
        point_labels: NDArray | None = None,
        box: NDArray | None = None,
        mask_input: NDArray | None = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        hq_token_only: bool = False,
    ) -> tuple[NDArray, NDArray, NDArray]:
        """Stub signature only; the real predictor provides the implementation."""
        ...
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SAMPrompt:
|
class SAMPrompt:
|
||||||
foreground_points: Sequence[tuple[float, float]] | None = None
|
foreground_points: Sequence[tuple[float, float]] | None = None
|
||||||
|
|
Loading…
Reference in a new issue