mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
Add HQ-SAM Adapter
This commit is contained in:
parent
c6b5eb24a1
commit
a93ceff752
|
@ -44,6 +44,11 @@ test = [
|
|||
# An unofficial Python package for Meta AI's Segment Anything Model:
|
||||
# https://github.com/opengeos/segment-anything
|
||||
"segment-anything-py>=1.0",
|
||||
# Official Python package for HQ-SAM
|
||||
"segment-anything-hq>=0.3",
|
||||
# HQ-SAM missing dependency:
|
||||
# https://github.com/SysCV/sam-hq/pull/59
|
||||
"timm>=0.5.0",
|
||||
]
|
||||
conversion = [
|
||||
"diffusers>=0.26.1",
|
||||
|
@ -140,3 +145,10 @@ exclude_also = [
|
|||
[tool.typos.default]
|
||||
extend-words = { adaptee = "adaptee" }
|
||||
extend-ignore-identifiers-re = ["NDArray*"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
filterwarnings = [
|
||||
"ignore::UserWarning:segment_anything_hq.modeling.tiny_vit_sam.*",
|
||||
"ignore::DeprecationWarning:timm.models.layers.*",
|
||||
"ignore::DeprecationWarning:timm.models.registry.*"
|
||||
]
|
||||
|
|
|
@ -255,6 +255,7 @@ safetensors==0.4.2
|
|||
# via transformers
|
||||
scipy==1.12.0
|
||||
# via bitsandbytes
|
||||
segment-anything-hq==0.3
|
||||
segment-anything-py==1.0
|
||||
# via refiners
|
||||
sentry-sdk==1.40.6
|
||||
|
@ -270,6 +271,7 @@ smmap==5.0.1
|
|||
# via gitdb
|
||||
sympy==1.12
|
||||
# via torch
|
||||
timm==0.9.16
|
||||
tokenizers==0.15.2
|
||||
# via transformers
|
||||
tomli==2.0.1
|
||||
|
|
81
scripts/conversion/convert_hq_segment_anything.py
Normal file
81
scripts/conversion/convert_hq_segment_anything.py
Normal file
|
@ -0,0 +1,81 @@
|
|||
import argparse
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from refiners.fluxion.utils import load_tensors, save_to_safetensors
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Convert HQ SAM model to Refiners state_dict format")
|
||||
parser.add_argument(
|
||||
"--from",
|
||||
type=str,
|
||||
dest="source_path",
|
||||
required=True,
|
||||
default="sam_hq_vit_h.pth",
|
||||
help="Path to the source model checkpoint.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--to",
|
||||
type=str,
|
||||
dest="output_path",
|
||||
required=True,
|
||||
default="refiners_sam_hq_vit_h.safetensors",
|
||||
help="Path to save the converted model in Refiners format.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
source_state_dict = load_tensors(args.source_path)
|
||||
|
||||
state_dict: dict[str, Tensor] = {}
|
||||
|
||||
for suffix in ["weight", "bias"]:
|
||||
state_dict[f"HQFeatures.CompressViTFeat.ConvTranspose2d_1.{suffix}"] = source_state_dict[
|
||||
f"mask_decoder.compress_vit_feat.0.{suffix}"
|
||||
]
|
||||
state_dict[f"HQFeatures.EmbeddingEncoder.ConvTranspose2d_1.{suffix}"] = source_state_dict[
|
||||
f"mask_decoder.embedding_encoder.0.{suffix}"
|
||||
]
|
||||
state_dict[f"EmbeddingMaskfeature.Conv2d_1.{suffix}"] = source_state_dict[
|
||||
f"mask_decoder.embedding_maskfeature.0.{suffix}"
|
||||
]
|
||||
|
||||
state_dict[f"HQFeatures.CompressViTFeat.LayerNorm2d.{suffix}"] = source_state_dict[
|
||||
f"mask_decoder.compress_vit_feat.1.{suffix}"
|
||||
]
|
||||
state_dict[f"HQFeatures.EmbeddingEncoder.LayerNorm2d.{suffix}"] = source_state_dict[
|
||||
f"mask_decoder.embedding_encoder.1.{suffix}"
|
||||
]
|
||||
state_dict[f"EmbeddingMaskfeature.LayerNorm2d.{suffix}"] = source_state_dict[
|
||||
f"mask_decoder.embedding_maskfeature.1.{suffix}"
|
||||
]
|
||||
|
||||
state_dict[f"HQFeatures.CompressViTFeat.ConvTranspose2d_2.{suffix}"] = source_state_dict[
|
||||
f"mask_decoder.compress_vit_feat.3.{suffix}"
|
||||
]
|
||||
state_dict[f"HQFeatures.EmbeddingEncoder.ConvTranspose2d_2.{suffix}"] = source_state_dict[
|
||||
f"mask_decoder.embedding_encoder.3.{suffix}"
|
||||
]
|
||||
state_dict[f"EmbeddingMaskfeature.Conv2d_2.{suffix}"] = source_state_dict[
|
||||
f"mask_decoder.embedding_maskfeature.3.{suffix}"
|
||||
]
|
||||
|
||||
state_dict = {f"Chain.HQSAMMaskPrediction.Chain.DenseEmbeddingUpscalingHQ.{k}": v for k, v in state_dict.items()}
|
||||
|
||||
# HQ Token
|
||||
state_dict["MaskDecoderTokensExtender.hq_token.weight"] = source_state_dict["mask_decoder.hf_token.weight"]
|
||||
|
||||
# HQ MLP
|
||||
for i in range(3):
|
||||
state_dict[f"Chain.HQSAMMaskPrediction.HQTokenMLP.MultiLinear.Linear_{i+1}.weight"] = source_state_dict[
|
||||
f"mask_decoder.hf_mlp.layers.{i}.weight"
|
||||
]
|
||||
state_dict[f"Chain.HQSAMMaskPrediction.HQTokenMLP.MultiLinear.Linear_{i+1}.bias"] = source_state_dict[
|
||||
f"mask_decoder.hf_mlp.layers.{i}.bias"
|
||||
]
|
||||
|
||||
save_to_safetensors(path=args.output_path, tensors=state_dict)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -185,17 +185,16 @@ def convert_mask_decoder(mask_decoder: nn.Module) -> dict[str, Tensor]:
|
|||
|
||||
mapping = converter.map_state_dicts(source_args=inputs, target_args={})
|
||||
assert mapping is not None
|
||||
mapping["IOUMaskEncoder"] = "iou_token"
|
||||
mapping["MaskDecoderTokens.Parameter"] = "iou_token"
|
||||
|
||||
state_dict = converter._convert_state_dict( # type: ignore
|
||||
source_state_dict=mask_decoder.state_dict(),
|
||||
target_state_dict=refiners_mask_decoder.state_dict(),
|
||||
state_dict_mapping=mapping,
|
||||
)
|
||||
state_dict["IOUMaskEncoder.weight"] = torch.cat(
|
||||
state_dict["MaskDecoderTokens.Parameter.weight"] = torch.cat(
|
||||
tensors=[mask_decoder.iou_token.weight, mask_decoder.mask_tokens.weight], dim=0
|
||||
) # type: ignore
|
||||
|
||||
refiners_mask_decoder.load_state_dict(state_dict=state_dict)
|
||||
|
||||
refiners_mask_decoder.set_image_embedding(image_embedding)
|
||||
|
@ -254,10 +253,10 @@ def main() -> None:
|
|||
mask_encoder_state_dict = convert_mask_encoder(prompt_encoder=sam_h.prompt_encoder)
|
||||
|
||||
output_state_dict = {
|
||||
**{".".join(("image_encoder", key)): value for key, value in vit_state_dict.items()},
|
||||
**{".".join(("mask_decoder", key)): value for key, value in mask_decoder_state_dict.items()},
|
||||
**{".".join(("point_encoder", key)): value for key, value in point_encoder_state_dict.items()},
|
||||
**{".".join(("mask_encoder", key)): value for key, value in mask_encoder_state_dict.items()},
|
||||
**{f"SAMViTH.{key}": value for key, value in vit_state_dict.items()},
|
||||
**{f"MaskDecoder.{key}": value for key, value in mask_decoder_state_dict.items()},
|
||||
**{f"PointEncoder.{key}": value for key, value in point_encoder_state_dict.items()},
|
||||
**{f"MaskEncoder.{key}": value for key, value in mask_encoder_state_dict.items()},
|
||||
}
|
||||
if args.half:
|
||||
output_state_dict = {key: value.half() for key, value in output_state_dict.items()}
|
||||
|
|
|
@ -366,6 +366,13 @@ def download_sam():
|
|||
)
|
||||
|
||||
|
||||
def download_hq_sam():
|
||||
weights_folder = os.path.join(test_weights_dir)
|
||||
download_file(
|
||||
"https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth", weights_folder, expected_hash="66da2472"
|
||||
)
|
||||
|
||||
|
||||
def download_dinov2():
|
||||
# For conversion
|
||||
weights_folder = os.path.join(test_weights_dir)
|
||||
|
@ -661,7 +668,16 @@ def convert_sam():
|
|||
"convert_segment_anything.py",
|
||||
"tests/weights/sam_vit_h_4b8939.pth",
|
||||
"tests/weights/segment-anything-h.safetensors",
|
||||
expected_hash="b62ad5ed",
|
||||
expected_hash="5ffb976f",
|
||||
)
|
||||
|
||||
|
||||
def convert_hq_sam():
|
||||
run_conversion_script(
|
||||
"convert_hq_segment_anything.py",
|
||||
"tests/weights/sam_hq_vit_h.pth",
|
||||
"tests/weights/refiners-sam-hq-vit-h.safetensors",
|
||||
expected_hash="b2f5e79f",
|
||||
)
|
||||
|
||||
|
||||
|
@ -769,6 +785,7 @@ def download_all():
|
|||
download_ip_adapter()
|
||||
download_t2i_adapter()
|
||||
download_sam()
|
||||
download_hq_sam()
|
||||
download_dinov2()
|
||||
download_control_lora_fooocus()
|
||||
download_lcm_base()
|
||||
|
@ -789,6 +806,7 @@ def convert_all():
|
|||
convert_ip_adapter()
|
||||
convert_t2i_adapter()
|
||||
convert_sam()
|
||||
convert_hq_sam()
|
||||
convert_dinov2()
|
||||
convert_control_lora_fooocus()
|
||||
convert_lcm_base()
|
||||
|
|
378
src/refiners/foundationals/segment_anything/hq_sam.py
Normal file
378
src/refiners/foundationals/segment_anything/hq_sam.py
Normal file
|
@ -0,0 +1,378 @@
|
|||
import torch
|
||||
from torch import device as Device, dtype as DType
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.adapters import Adapter
|
||||
from refiners.fluxion.context import Contexts
|
||||
from refiners.foundationals.segment_anything.image_encoder import SAMViT, TransformerLayer
|
||||
from refiners.foundationals.segment_anything.mask_decoder import (
|
||||
MaskDecoderTokens,
|
||||
MaskPrediction,
|
||||
Predictions,
|
||||
)
|
||||
from refiners.foundationals.segment_anything.model import SegmentAnything
|
||||
|
||||
|
||||
class CompressViTFeat(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
transformer_dim: int = 256,
|
||||
vit_dim: int = 1024,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
fl.UseContext(context="hq_sam", key="early_vit_embedding"),
|
||||
fl.Permute(0, 3, 1, 2),
|
||||
fl.ConvTranspose2d(
|
||||
in_channels=vit_dim,
|
||||
out_channels=transformer_dim,
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
fl.LayerNorm2d(channels=transformer_dim, device=device, dtype=dtype),
|
||||
fl.GeLU(),
|
||||
fl.ConvTranspose2d(
|
||||
in_channels=transformer_dim,
|
||||
out_channels=transformer_dim // 8,
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingEncoder(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
transformer_dim: int = 256,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
fl.UseContext(context="mask_decoder", key="image_embedding"),
|
||||
fl.ConvTranspose2d(
|
||||
in_channels=transformer_dim,
|
||||
out_channels=transformer_dim // 4,
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
fl.LayerNorm2d(channels=transformer_dim // 4, device=device, dtype=dtype),
|
||||
fl.GeLU(),
|
||||
fl.ConvTranspose2d(
|
||||
in_channels=transformer_dim // 4,
|
||||
out_channels=transformer_dim // 8,
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class HQFeatures(fl.Sum):
|
||||
def __init__(
|
||||
self,
|
||||
vit_dim: int = 1024,
|
||||
transformer_dim: int = 256,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
EmbeddingEncoder(transformer_dim, device, dtype),
|
||||
CompressViTFeat(transformer_dim, vit_dim, device, dtype),
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingMaskfeature(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
transformer_dim: int = 256,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
fl.UseContext(context="mask_decoder", key="upscaled_dense_embedding"),
|
||||
fl.Reshape(-1, transformer_dim, transformer_dim),
|
||||
fl.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1, device=device, dtype=dtype),
|
||||
fl.LayerNorm2d(transformer_dim // 4, device=device, dtype=dtype),
|
||||
fl.GeLU(),
|
||||
fl.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
|
||||
class DenseEmbeddingUpscalingHQ(fl.Sum):
|
||||
def __init__(
|
||||
self,
|
||||
vit_dim: int = 1024,
|
||||
transformer_dim: int = 256,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
EmbeddingMaskfeature(transformer_dim, device, dtype),
|
||||
HQFeatures(vit_dim, transformer_dim, device, dtype),
|
||||
)
|
||||
|
||||
|
||||
class HQTokenMLP(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
num_layers: int = 3,
|
||||
target_num_mask_tokens: int = 5,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
fl.Slicing(dim=1, start=target_num_mask_tokens, end=target_num_mask_tokens + 1), # HQ token
|
||||
fl.MultiLinear(
|
||||
input_dim=embedding_dim,
|
||||
output_dim=embedding_dim // 8,
|
||||
inner_dim=embedding_dim,
|
||||
num_layers=num_layers,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class HQSAMMaskPrediction(fl.Matmul):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
vit_dim: int = 1024,
|
||||
target_num_mask_tokens: int = 5,
|
||||
num_layers: int = 3,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
HQTokenMLP(
|
||||
embedding_dim,
|
||||
num_layers=num_layers,
|
||||
target_num_mask_tokens=target_num_mask_tokens,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
fl.Chain(
|
||||
DenseEmbeddingUpscalingHQ(vit_dim=vit_dim, transformer_dim=256, device=device, dtype=dtype),
|
||||
fl.Flatten(start_dim=2),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class MaskPredictionAdapter(fl.Concatenate, Adapter[MaskPrediction]):
|
||||
def __init__(
|
||||
self,
|
||||
target: MaskPrediction,
|
||||
vit_dim: int = 1024,
|
||||
target_num_mask_tokens: int = 5,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(
|
||||
target,
|
||||
fl.Chain(
|
||||
HQSAMMaskPrediction(
|
||||
embedding_dim=target.embedding_dim,
|
||||
vit_dim=vit_dim,
|
||||
target_num_mask_tokens=target_num_mask_tokens,
|
||||
num_layers=3,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
fl.Reshape(-1, target.embedding_dim, target.embedding_dim),
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
@property
|
||||
def hq_sam_mask_prediction(self) -> HQSAMMaskPrediction:
|
||||
return self.ensure_find(HQSAMMaskPrediction)
|
||||
|
||||
|
||||
class MaskDecoderTokensExtender(fl.Concatenate, Adapter[MaskDecoderTokens]):
|
||||
"""
|
||||
Add a new weight to the MaskDecoderTokens to store the new HQ token.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: MaskDecoderTokens,
|
||||
) -> None:
|
||||
self._hq_token = [fl.Parameter(1, target.embedding_dim, device=target.device, dtype=target.dtype)]
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(
|
||||
target,
|
||||
fl.Chain(
|
||||
fl.UseContext(context="mask_decoder", key="image_embedding"), # use Context to infer batch size
|
||||
self.hq_token,
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
@property
|
||||
def regular_tokens(self) -> fl.Parameter:
|
||||
return self.target.ensure_find(fl.Parameter)
|
||||
|
||||
@property
|
||||
def hq_token(self) -> fl.Parameter:
|
||||
return self._hq_token[0]
|
||||
|
||||
|
||||
class SAMViTAdapter(fl.Chain, Adapter[SAMViT]):
|
||||
"""
|
||||
Add a context to the image encoder transformer to store its early ViT embedding
|
||||
(first intermediate embedding of the ViT).
|
||||
"""
|
||||
|
||||
def __init__(self, target: SAMViT) -> None:
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(target)
|
||||
target_transformer_layer = self._find_target_transformer_layer()
|
||||
assert target_transformer_layer is not None
|
||||
self._transformer_layer = [target_transformer_layer]
|
||||
self._set_early_vit_embedding_context = [fl.SetContext("hq_sam", "early_vit_embedding")]
|
||||
|
||||
@property
|
||||
def target_transformer_layer(self) -> TransformerLayer:
|
||||
return self._transformer_layer[0]
|
||||
|
||||
@property
|
||||
def set_early_vit_embedding_context(self) -> fl.SetContext:
|
||||
return self._set_early_vit_embedding_context[0]
|
||||
|
||||
def _find_target_transformer_layer(self) -> TransformerLayer | None:
|
||||
for transformer_layer in self.target.layers(TransformerLayer):
|
||||
if transformer_layer.window_size is None:
|
||||
return transformer_layer
|
||||
return None
|
||||
|
||||
def inject(self: "SAMViTAdapter", parent: fl.Chain | None = None) -> "SAMViTAdapter":
|
||||
self.target_transformer_layer.append(self.set_early_vit_embedding_context)
|
||||
return super().inject(parent)
|
||||
|
||||
def eject(self) -> None:
|
||||
self.target_transformer_layer.remove(self.set_early_vit_embedding_context)
|
||||
super().eject()
|
||||
|
||||
|
||||
class PredictionsPostProc(fl.Module):
|
||||
def __init__(self, hq_mask_only: bool = False) -> None:
|
||||
super().__init__()
|
||||
self.hq_mask_only = hq_mask_only
|
||||
|
||||
def forward(
|
||||
self, masks_predictions: torch.Tensor, iou_predictions: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
hq_sam_mask = masks_predictions[:, -1:, ...]
|
||||
|
||||
# The official implementation of HQ-SAM has two outputs modes:
|
||||
# 1. HQ mask only
|
||||
# 2. HQ mask + base SAM mask, using HQ as a correction to the base SAM mask
|
||||
# Details can be found in the paper: https://arxiv.org/abs/2306.01567 (section 3.3)
|
||||
# Heuristics are provided by the authors here: https://github.com/SysCV/sam-hq/blob/3224888/demo/demo_hqsam_pip_example.py#L73-L75
|
||||
if self.hq_mask_only:
|
||||
return (hq_sam_mask, iou_predictions)
|
||||
|
||||
base_sam_masks = masks_predictions[:, :-1, ...]
|
||||
assert base_sam_masks.shape[1] == 1
|
||||
return (hq_sam_mask + base_sam_masks, iou_predictions)
|
||||
|
||||
|
||||
class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]):
|
||||
"""Adapter for SAM introducing HQ features.
|
||||
|
||||
See [[arXiv:2306.01567] Segment Anything in High Quality](https://arxiv.org/abs/2306.01567) for details.
|
||||
"""
|
||||
|
||||
def init_context(self) -> Contexts:
|
||||
return {"hq_sam": {"early_vit_embedding": None}}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: SegmentAnything,
|
||||
hq_mask_only: bool = False,
|
||||
weights: dict[str, torch.Tensor] | None = None,
|
||||
) -> None:
|
||||
self.vit_embedding_dim = target.image_encoder.embedding_dim
|
||||
self.target_num_mask_tokens = target.mask_decoder.num_multimask_outputs + 2
|
||||
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(target)
|
||||
|
||||
if target.mask_decoder.multimask_output:
|
||||
raise NotImplementedError("Multi-mask mode is not supported in HQSAMAdapter.")
|
||||
|
||||
mask_prediction = target.mask_decoder.ensure_find(MaskPrediction)
|
||||
|
||||
self._mask_prediction_adapter = [
|
||||
MaskPredictionAdapter(
|
||||
mask_prediction, self.vit_embedding_dim, self.target_num_mask_tokens, target.device, target.dtype
|
||||
)
|
||||
]
|
||||
self._image_encoder_adapter = [SAMViTAdapter(target.image_encoder)]
|
||||
self._predictions_post_proc = [PredictionsPostProc(hq_mask_only)]
|
||||
|
||||
mask_decoder_tokens = target.mask_decoder.ensure_find(MaskDecoderTokens)
|
||||
self._mask_decoder_tokens_extender = [MaskDecoderTokensExtender(mask_decoder_tokens)]
|
||||
|
||||
if weights is not None:
|
||||
hq_token_prefix = "MaskDecoderTokensExtender.hq_token."
|
||||
hq_token_state_dict: dict[str, torch.Tensor] = {
|
||||
k.removeprefix(hq_token_prefix): v for k, v in weights.items() if k.startswith(hq_token_prefix)
|
||||
}
|
||||
self.mask_decoder_tokens_extender.hq_token.load_state_dict(hq_token_state_dict)
|
||||
|
||||
mask_pred_prefix = "Chain.HQSAMMaskPrediction."
|
||||
mask_pred_state_dict: dict[str, torch.Tensor] = {
|
||||
k.removeprefix(mask_pred_prefix): v for k, v in weights.items() if k.startswith(mask_pred_prefix)
|
||||
}
|
||||
self.mask_prediction_adapter.hq_sam_mask_prediction.load_state_dict(mask_pred_state_dict)
|
||||
|
||||
self.to(device=target.device, dtype=target.dtype)
|
||||
|
||||
@property
|
||||
def mask_decoder_tokens_extender(self) -> MaskDecoderTokensExtender:
|
||||
return self._mask_decoder_tokens_extender[0]
|
||||
|
||||
@property
|
||||
def mask_prediction_adapter(self) -> MaskPredictionAdapter:
|
||||
return self._mask_prediction_adapter[0]
|
||||
|
||||
@property
|
||||
def image_encoder_adapter(self) -> SAMViTAdapter:
|
||||
return self._image_encoder_adapter[0]
|
||||
|
||||
@property
|
||||
def predictions_post_proc(self) -> PredictionsPostProc:
|
||||
return self._predictions_post_proc[0]
|
||||
|
||||
@property
|
||||
def hq_mask_only(self) -> bool:
|
||||
return self.predictions_post_proc.hq_mask_only
|
||||
|
||||
@hq_mask_only.setter
|
||||
def hq_mask_only(self, value: bool) -> None:
|
||||
self.predictions_post_proc.hq_mask_only = value
|
||||
|
||||
def inject(self: "HQSAMAdapter", parent: fl.Chain | None = None) -> "HQSAMAdapter":
|
||||
self.mask_decoder_tokens_extender.inject()
|
||||
self.mask_prediction_adapter.inject()
|
||||
self.image_encoder_adapter.inject()
|
||||
self.target.mask_decoder.insert_after_type(Predictions, self.predictions_post_proc)
|
||||
return super().inject(parent)
|
||||
|
||||
def eject(self) -> None:
|
||||
self.mask_decoder_tokens_extender.eject()
|
||||
self.mask_prediction_adapter.eject()
|
||||
self.image_encoder_adapter.eject()
|
||||
self.target.mask_decoder.remove(self.predictions_post_proc)
|
||||
super().eject()
|
|
@ -1,5 +1,5 @@
|
|||
import torch
|
||||
from torch import Tensor, device as Device, dtype as DType, nn
|
||||
from torch import Tensor, device as Device, dtype as DType
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.context import Contexts
|
||||
|
@ -10,7 +10,7 @@ from refiners.foundationals.segment_anything.transformer import (
|
|||
|
||||
|
||||
class EmbeddingsAggregator(fl.ContextModule):
|
||||
def forward(self, iou_mask_tokens: Tensor) -> Tensor:
|
||||
def forward(self, tokens: Tensor) -> Tensor:
|
||||
mask_decoder = self.ensure_parent
|
||||
mask_decoder_context = mask_decoder.use_context(context_name="mask_decoder")
|
||||
image_embedding = mask_decoder_context["image_embedding"]
|
||||
|
@ -18,7 +18,7 @@ class EmbeddingsAggregator(fl.ContextModule):
|
|||
mask_embedding = mask_decoder_context["mask_embedding"]
|
||||
dense_positional_embedding = mask_decoder_context["dense_positional_embedding"]
|
||||
|
||||
sparse_embedding = torch.cat(tensors=(iou_mask_tokens, point_embedding), dim=1)
|
||||
sparse_embedding = torch.cat(tensors=(tokens, point_embedding), dim=1)
|
||||
dense_embedding = (image_embedding + mask_embedding).flatten(start_dim=2).transpose(1, 2)
|
||||
if dense_positional_embedding.shape != dense_embedding.shape:
|
||||
dense_positional_embedding = dense_positional_embedding.flatten(start_dim=2).transpose(1, 2)
|
||||
|
@ -108,10 +108,11 @@ class DenseEmbeddingUpscaling(fl.Chain):
|
|||
),
|
||||
fl.GeLU(),
|
||||
fl.Flatten(start_dim=2),
|
||||
fl.SetContext(context="mask_decoder", key="upscaled_dense_embedding"),
|
||||
)
|
||||
|
||||
|
||||
class IOUMaskEncoder(fl.WeightedModule):
|
||||
class MaskDecoderTokens(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 256,
|
||||
|
@ -119,14 +120,13 @@ class IOUMaskEncoder(fl.WeightedModule):
|
|||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_mask_tokens = num_mask_tokens
|
||||
# aka prompt tokens + output token (for IoU scores prediction)
|
||||
self.weight = nn.Parameter(data=torch.randn(num_mask_tokens + 1, embedding_dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self) -> Tensor:
|
||||
return self.weight.unsqueeze(dim=0)
|
||||
# aka output tokens (single-mask output + multi-mask output) + IoU token
|
||||
super().__init__(
|
||||
fl.UseContext(context="mask_decoder", key="image_embedding"), # use Context to infer batch size
|
||||
fl.Parameter(num_mask_tokens + 1, embedding_dim, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
|
||||
class MaskPrediction(fl.Chain):
|
||||
|
@ -178,7 +178,6 @@ class IOUPrediction(fl.Chain):
|
|||
|
||||
super().__init__(
|
||||
fl.Slicing(dim=1, start=0, end=1),
|
||||
fl.Squeeze(dim=0),
|
||||
fl.MultiLinear(
|
||||
input_dim=embedding_dim,
|
||||
output_dim=num_mask_tokens,
|
||||
|
@ -188,6 +187,39 @@ class IOUPrediction(fl.Chain):
|
|||
dtype=dtype,
|
||||
),
|
||||
fl.Slicing(dim=-1, start=1) if multimask_output else fl.Slicing(dim=-1, start=0, end=1),
|
||||
fl.Squeeze(dim=1),
|
||||
)
|
||||
|
||||
|
||||
class Predictions(fl.Parallel):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
num_mask_tokens: int,
|
||||
multimask_output: bool,
|
||||
num_layers: int = 3,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_mask_tokens = num_mask_tokens
|
||||
self.num_layers = num_layers
|
||||
super().__init__(
|
||||
MaskPrediction(
|
||||
embedding_dim=embedding_dim,
|
||||
num_mask_tokens=num_mask_tokens,
|
||||
multimask_output=multimask_output,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
IOUPrediction(
|
||||
embedding_dim=embedding_dim,
|
||||
num_layers=num_layers,
|
||||
num_mask_tokens=num_mask_tokens,
|
||||
multimask_output=multimask_output,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
@ -213,7 +245,7 @@ class MaskDecoder(fl.Chain):
|
|||
num_mask_tokens = self.num_multimask_outputs + 1
|
||||
|
||||
super().__init__(
|
||||
IOUMaskEncoder(embedding_dim=embedding_dim, num_mask_tokens=num_mask_tokens, device=device, dtype=dtype),
|
||||
MaskDecoderTokens(embedding_dim=embedding_dim, num_mask_tokens=num_mask_tokens, device=device, dtype=dtype),
|
||||
EmbeddingsAggregator(),
|
||||
Transformer(
|
||||
*(
|
||||
|
@ -230,23 +262,13 @@ class MaskDecoder(fl.Chain):
|
|||
SparseCrossDenseAttention(embedding_dim=embedding_dim, device=device, dtype=dtype),
|
||||
fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
|
||||
),
|
||||
fl.Parallel(
|
||||
MaskPrediction(
|
||||
Predictions(
|
||||
embedding_dim=embedding_dim,
|
||||
num_mask_tokens=num_mask_tokens,
|
||||
multimask_output=multimask_output,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
IOUPrediction(
|
||||
embedding_dim=embedding_dim,
|
||||
num_layers=3,
|
||||
num_mask_tokens=num_mask_tokens,
|
||||
multimask_output=multimask_output,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
def init_context(self) -> Contexts:
|
||||
|
|
|
@ -20,7 +20,7 @@ class ImageEmbedding:
|
|||
original_image_size: tuple[int, int] # (height, width)
|
||||
|
||||
|
||||
class SegmentAnything(fl.Module):
|
||||
class SegmentAnything(fl.Chain):
|
||||
"""SegmentAnything model.
|
||||
|
||||
See [[arXiv:2304.02643] Segment Anything](https://arxiv.org/abs/2304.02643)
|
||||
|
@ -47,16 +47,30 @@ class SegmentAnything(fl.Module):
|
|||
point_encoder: The point encoder to use.
|
||||
mask_encoder: The mask encoder to use.
|
||||
mask_decoder: The mask decoder to use.
|
||||
device: The PyTorch device to use.
|
||||
dtype: The PyTorch data type to use.
|
||||
"""
|
||||
super().__init__()
|
||||
self.device: Device = device if isinstance(device, Device) else Device(device=device)
|
||||
self.dtype = dtype
|
||||
self.image_encoder = image_encoder.to(device=self.device, dtype=self.dtype)
|
||||
self.point_encoder = point_encoder.to(device=self.device, dtype=self.dtype)
|
||||
self.mask_encoder = mask_encoder.to(device=self.device, dtype=self.dtype)
|
||||
self.mask_decoder = mask_decoder.to(device=self.device, dtype=self.dtype)
|
||||
super().__init__(image_encoder, point_encoder, mask_encoder, mask_decoder)
|
||||
|
||||
self.to(device=device, dtype=dtype)
|
||||
|
||||
@property
|
||||
def image_encoder(self) -> SAMViT:
|
||||
"""The image encoder."""
|
||||
return self.ensure_find(SAMViT)
|
||||
|
||||
@property
|
||||
def point_encoder(self) -> PointEncoder:
|
||||
"""The point encoder."""
|
||||
return self.ensure_find(PointEncoder)
|
||||
|
||||
@property
|
||||
def mask_encoder(self) -> MaskEncoder:
|
||||
"""The mask encoder."""
|
||||
return self.ensure_find(MaskEncoder)
|
||||
|
||||
@property
|
||||
def mask_decoder(self) -> MaskDecoder:
|
||||
"""The mask decoder."""
|
||||
return self.ensure_find(MaskDecoder)
|
||||
|
||||
@no_grad()
|
||||
def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
|
||||
|
@ -259,11 +273,11 @@ class SegmentAnythingH(SegmentAnything):
|
|||
else:
|
||||
mask_decoder = MaskDecoder(multimask_output) if multimask_output is not None else MaskDecoder()
|
||||
|
||||
super().__init__(
|
||||
image_encoder=image_encoder,
|
||||
point_encoder=point_encoder,
|
||||
mask_encoder=mask_encoder,
|
||||
mask_decoder=mask_decoder,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
super().__init__(image_encoder, point_encoder, mask_encoder, mask_decoder)
|
||||
|
||||
self.to(device=device, dtype=dtype)
|
||||
|
||||
@property
|
||||
def image_encoder(self) -> SAMViTH:
|
||||
"""The image encoder."""
|
||||
return self.ensure_find(SAMViTH)
|
||||
|
|
|
@ -98,7 +98,7 @@ class PointEncoder(fl.Chain):
|
|||
) -> Float[Tensor, "num_positional_features height width"]:
|
||||
coordinate_encoder = self.ensure_find(layer_type=CoordinateEncoder)
|
||||
height, width = image_embedding_size
|
||||
grid = torch.ones((height, width), device=self.device, dtype=torch.float32)
|
||||
grid = torch.ones((height, width), device=self.device, dtype=self.dtype)
|
||||
y_embedding = grid.cumsum(dim=0) - 0.5
|
||||
x_embedding = grid.cumsum(dim=1) - 0.5
|
||||
y_embedding = y_embedding / height
|
||||
|
|
18
tests/foundationals/segment_anything/conftest.py
Normal file
18
tests/foundationals/segment_anything/conftest.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
from pathlib import Path
|
||||
from warnings import warn
|
||||
|
||||
from pytest import fixture, skip
|
||||
|
||||
|
||||
@fixture(scope="package")
|
||||
def ref_path(test_sam_path: Path) -> Path:
|
||||
return test_sam_path / "test_sam_ref"
|
||||
|
||||
|
||||
@fixture(scope="package")
|
||||
def sam_h_weights(test_weights_path: Path) -> Path:
|
||||
sam_h_weights = test_weights_path / "segment-anything-h.safetensors"
|
||||
if not sam_h_weights.is_file():
|
||||
warn(f"could not find weights at {sam_h_weights}, skipping")
|
||||
skip(allow_module_level=True)
|
||||
return sam_h_weights
|
296
tests/foundationals/segment_anything/test_hq_sam.py
Normal file
296
tests/foundationals/segment_anything/test_hq_sam.py
Normal file
|
@ -0,0 +1,296 @@
|
|||
from pathlib import Path
|
||||
from typing import cast
|
||||
from warnings import warn
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
from segment_anything_hq import ( # type: ignore
|
||||
SamPredictor as SamPredictorHQ,
|
||||
sam_model_registry as sam_model_registry_hq, # type: ignore
|
||||
)
|
||||
from segment_anything_hq.modeling.sam import Sam # type: ignore
|
||||
from tests.foundationals.segment_anything.utils import FacebookSAM, FacebookSAMPredictorHQ, SAMPrompt
|
||||
from torch import optim
|
||||
|
||||
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, no_grad
|
||||
from refiners.foundationals.segment_anything.hq_sam import (
|
||||
CompressViTFeat,
|
||||
EmbeddingEncoder,
|
||||
HQSAMAdapter,
|
||||
HQTokenMLP,
|
||||
MaskDecoderTokensExtender,
|
||||
PredictionsPostProc,
|
||||
)
|
||||
from refiners.foundationals.segment_anything.model import SegmentAnythingH
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def one_prompt() -> SAMPrompt:
|
||||
return SAMPrompt(box_points=[[(4, 13), (1007, 1023)]])
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def tennis(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "tennis.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def hq_adapter_weights(test_weights_path: Path) -> Path:
|
||||
"""Path to the HQ adapter weights in Refiners format"""
|
||||
refiners_hq_adapter_sam_weights = test_weights_path / "refiners-sam-hq-vit-h.safetensors"
|
||||
if not refiners_hq_adapter_sam_weights.is_file():
|
||||
warn(f"Test weights not found at {refiners_hq_adapter_sam_weights}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
return refiners_hq_adapter_sam_weights
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
|
||||
# HQSAMAdapter is designed to be used with single-output only, hence multimask_output=False.
|
||||
sam_h = SegmentAnythingH(multimask_output=False, device=test_device)
|
||||
sam_h.load_from_safetensors(tensors_path=sam_h_weights)
|
||||
return sam_h
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def reference_hq_adapter_weights(test_weights_path: Path) -> Path:
|
||||
"""Path to the HQ adapter weights in default format"""
|
||||
reference_hq_adapter_sam_weights = test_weights_path / "sam_hq_vit_h.pth"
|
||||
if not reference_hq_adapter_sam_weights.is_file():
|
||||
warn(f"Test weights not found at {reference_hq_adapter_sam_weights}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
return reference_hq_adapter_sam_weights
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def reference_sam_h(reference_hq_adapter_weights: Path, test_device: torch.device) -> FacebookSAM:
|
||||
sam_h = cast(FacebookSAM, sam_model_registry_hq["vit_h"](checkpoint=reference_hq_adapter_weights))
|
||||
return sam_h.to(device=test_device)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def reference_sam_h_predictor(reference_sam_h: FacebookSAM) -> FacebookSAMPredictorHQ:
|
||||
predictor = SamPredictorHQ(cast(Sam, reference_sam_h))
|
||||
return cast(FacebookSAMPredictorHQ, predictor)
|
||||
|
||||
|
||||
def test_inject_eject() -> None:
|
||||
sam_h = SegmentAnythingH(multimask_output=False)
|
||||
initial_repr = repr(sam_h)
|
||||
adapter = HQSAMAdapter(sam_h)
|
||||
assert repr(sam_h) == initial_repr
|
||||
adapter.inject()
|
||||
assert repr(sam_h) != initial_repr
|
||||
adapter.eject()
|
||||
assert repr(sam_h) == initial_repr
|
||||
|
||||
|
||||
def test_multimask_forbidden() -> None:
|
||||
with pytest.raises(NotImplementedError, match="not supported"):
|
||||
HQSAMAdapter(target=SegmentAnythingH(multimask_output=True))
|
||||
|
||||
|
||||
def test_output_shape_hq_adapter(tennis: Image.Image, one_prompt: SAMPrompt) -> None:
|
||||
sam_h = SegmentAnythingH(multimask_output=False)
|
||||
HQSAMAdapter(sam_h).inject()
|
||||
high_res_masks, iou_predictions, low_res_masks = sam_h.predict(tennis, **one_prompt.__dict__)
|
||||
assert high_res_masks.shape == (1, 1, 1024, 1024)
|
||||
assert iou_predictions.shape == (1, 1)
|
||||
assert low_res_masks.shape == (1, 1, 256, 256)
|
||||
|
||||
|
||||
def test_mask_decoder_tokens_extender() -> None:
|
||||
sam_h = SegmentAnythingH(multimask_output=False)
|
||||
sam_h.requires_grad_(False)
|
||||
|
||||
# MaskDecoderTokens requires image_embedding context to be set
|
||||
image_embedding = torch.randn(2, 256, 64, 64)
|
||||
sam_h.mask_decoder.set_image_embedding(image_embedding)
|
||||
|
||||
HQSAMAdapter(sam_h).inject()
|
||||
|
||||
mask_decoder_tokens = sam_h.ensure_find(MaskDecoderTokensExtender)
|
||||
|
||||
tokens_before = mask_decoder_tokens()
|
||||
assert tokens_before.shape == torch.Size([2, 6, 256])
|
||||
|
||||
for p in mask_decoder_tokens.parameters():
|
||||
match p.shape:
|
||||
case torch.Size([5, 256]):
|
||||
assert not p.requires_grad
|
||||
case torch.Size([1, 256]):
|
||||
assert p.requires_grad
|
||||
case _:
|
||||
raise ValueError
|
||||
|
||||
optimizer = optim.SGD(mask_decoder_tokens.parameters(), lr=10)
|
||||
optimizer.zero_grad()
|
||||
|
||||
ones = torch.ones_like(tokens_before)
|
||||
loss = torch.nn.functional.mse_loss(tokens_before, ones)
|
||||
loss.backward() # type: ignore
|
||||
optimizer.step()
|
||||
|
||||
tokens_after = mask_decoder_tokens()
|
||||
|
||||
assert torch.equal(tokens_before[:, :5, :], tokens_after[:, :5, :])
|
||||
assert not torch.equal(tokens_before[:, 5, :], tokens_after[:, 5, :])
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_early_vit_embedding(
|
||||
sam_h: SegmentAnythingH,
|
||||
hq_adapter_weights: Path,
|
||||
reference_sam_h: FacebookSAM,
|
||||
tennis: Image.Image,
|
||||
) -> None:
|
||||
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
|
||||
|
||||
image_tensor = image_to_tensor(image=tennis.resize(size=(1024, 1024)))
|
||||
|
||||
_ = sam_h.image_encoder(image_tensor.to(sam_h.device))
|
||||
early_vit_embedding_refiners = sam_h.use_context(context_name="hq_sam")["early_vit_embedding"]
|
||||
|
||||
_, intermediate_embeddings = reference_sam_h.image_encoder(image_tensor.to(reference_sam_h.device))
|
||||
early_vit_embedding = intermediate_embeddings[0]
|
||||
|
||||
assert torch.equal(early_vit_embedding, early_vit_embedding_refiners)
|
||||
|
||||
|
||||
def test_tokens(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
|
||||
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
|
||||
|
||||
mask_decoder_tokens_extender = sam_h.mask_decoder.ensure_find(MaskDecoderTokensExtender)
|
||||
|
||||
# HF Token (1, 256)
|
||||
assert torch.equal(reference_sam_h.mask_decoder.hf_token.weight, mask_decoder_tokens_extender.hq_token.weight)
|
||||
|
||||
# Regular Tokens (5, 256)
|
||||
assert torch.equal(
|
||||
torch.cat([reference_sam_h.mask_decoder.iou_token.weight, reference_sam_h.mask_decoder.mask_tokens.weight]),
|
||||
mask_decoder_tokens_extender.regular_tokens.weight,
|
||||
)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_compress_vit_feat(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
|
||||
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
|
||||
|
||||
early_vit_embedding = torch.randn(1, 64, 64, 1280, device=sam_h.device, dtype=sam_h.dtype)
|
||||
|
||||
sam_h.set_context(context="hq_sam", value={"early_vit_embedding": early_vit_embedding})
|
||||
refiners_output = sam_h.ensure_find(CompressViTFeat)()
|
||||
|
||||
reference_output = reference_sam_h.mask_decoder.compress_vit_feat(early_vit_embedding.permute(0, 3, 1, 2))
|
||||
|
||||
assert torch.equal(refiners_output, reference_output)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_embedding_encoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
|
||||
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
|
||||
|
||||
x = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype)
|
||||
|
||||
sam_h.set_context(context="mask_decoder", value={"image_embedding": x})
|
||||
refiners_output = sam_h.ensure_find(EmbeddingEncoder)()
|
||||
|
||||
reference_output = reference_sam_h.mask_decoder.embedding_encoder(x)
|
||||
|
||||
assert torch.equal(refiners_output, reference_output)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_hq_token_mlp(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
|
||||
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
|
||||
|
||||
x = torch.randn(1, 6, 256, device=sam_h.device, dtype=sam_h.dtype)
|
||||
|
||||
refiners_output = sam_h.ensure_find(HQTokenMLP)(x)
|
||||
reference_output = reference_sam_h.mask_decoder.hf_mlp(x[:, -1, :]).unsqueeze(0)
|
||||
|
||||
assert torch.equal(refiners_output, reference_output)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hq_mask_only", [True, False])
|
||||
def test_predictor(
|
||||
sam_h: SegmentAnythingH,
|
||||
hq_adapter_weights: Path,
|
||||
hq_mask_only: bool,
|
||||
reference_sam_h_predictor: FacebookSAMPredictorHQ,
|
||||
tennis: Image.Image,
|
||||
one_prompt: SAMPrompt,
|
||||
) -> None:
|
||||
adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
|
||||
|
||||
adapter.hq_mask_only = hq_mask_only
|
||||
assert sam_h.ensure_find(PredictionsPostProc).hq_mask_only == hq_mask_only
|
||||
|
||||
# Refiners
|
||||
high_res_masks, iou_predictions, low_res_masks = sam_h.predict(tennis, **one_prompt.__dict__)
|
||||
refiners_high_res_mask_hq = high_res_masks[0, 0, ...].to(dtype=torch.float32).detach().cpu()
|
||||
refiners_low_res_mask_hq = low_res_masks[0, 0, ...].to(dtype=torch.float32).detach().cpu()
|
||||
iou_predictions = iou_predictions[0, :].to(dtype=torch.float32).detach().cpu()
|
||||
|
||||
# Reference
|
||||
reference_sam_h_predictor.set_image(np.array(tennis))
|
||||
|
||||
predictor_prompt = one_prompt.__dict__["box_points"]
|
||||
masks_np, iou_predictions_np, low_res_masks_np = reference_sam_h_predictor.predict(
|
||||
box=np.array(predictor_prompt).flatten(),
|
||||
multimask_output=False,
|
||||
hq_token_only=hq_mask_only,
|
||||
)
|
||||
|
||||
reference_high_res_mask_hq = torch.from_numpy(masks_np[0, ...]).to(dtype=torch.float32) # type: ignore
|
||||
reference_low_res_mask_hq = torch.from_numpy(low_res_masks_np[0, ...]).to(dtype=torch.float32) # type: ignore
|
||||
iou_predictions_np = torch.from_numpy(iou_predictions_np).to(dtype=torch.float32) # type: ignore
|
||||
|
||||
# NOTE: Diff on logits is relatively high, but on the same scale / even lower than base SAM logits diff (6e-3)
|
||||
# See https://github.com/finegrain-ai/refiners/blob/c6b5eb24a179d48e4542d94684a70c5ef3142ab1/tests/foundationals/segment_anything/test_sam.py#L426
|
||||
assert torch.allclose(
|
||||
reference_low_res_mask_hq,
|
||||
refiners_low_res_mask_hq,
|
||||
atol=4e-3,
|
||||
)
|
||||
assert (
|
||||
torch.abs(reference_high_res_mask_hq - refiners_high_res_mask_hq).flatten().sum() <= 1
|
||||
) # The diff on the logits above leads to an absolute diff of 1 pixel on the high res masks
|
||||
assert torch.allclose(
|
||||
iou_predictions_np,
|
||||
torch.max(iou_predictions),
|
||||
atol=1e-5,
|
||||
)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) -> None:
|
||||
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
|
||||
|
||||
batch_size = 5
|
||||
|
||||
image_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1, 1)
|
||||
mask_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1, 1)
|
||||
dense_positional_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(
|
||||
batch_size, 1, 1, 1
|
||||
)
|
||||
point_embedding = torch.randn(1, 2, 256, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1)
|
||||
early_vit_embedding = torch.randn(1, 64, 64, 1280, device=sam_h.device, dtype=sam_h.dtype).repeat(
|
||||
batch_size, 1, 1, 1
|
||||
)
|
||||
|
||||
sam_h.mask_decoder.set_image_embedding(image_embedding)
|
||||
sam_h.mask_decoder.set_mask_embedding(mask_embedding)
|
||||
sam_h.mask_decoder.set_point_embedding(point_embedding)
|
||||
sam_h.mask_decoder.set_dense_positional_embedding(dense_positional_embedding)
|
||||
sam_h.mask_decoder.set_context(
|
||||
context="hq_sam", value={"early_vit_embedding": early_vit_embedding.to(sam_h.device, sam_h.dtype)}
|
||||
)
|
||||
|
||||
mask_prediction, iou_prediction = sam_h.mask_decoder()
|
||||
|
||||
assert mask_prediction.shape == (batch_size, 1, 256, 256)
|
||||
assert iou_prediction.shape == (batch_size, 1)
|
||||
assert torch.equal(mask_prediction[0], mask_prediction[1])
|
|
@ -56,15 +56,6 @@ def facebook_sam_h_weights(test_weights_path: Path) -> Path:
|
|||
return sam_h_weights
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def sam_h_weights(test_weights_path: Path) -> Path:
|
||||
sam_h_weights = test_weights_path / "segment-anything-h.safetensors"
|
||||
if not sam_h_weights.is_file():
|
||||
warn(f"could not find weights at {sam_h_weights}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
return sam_h_weights
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def facebook_sam_h(facebook_sam_h_weights: Path, test_device: torch.device) -> FacebookSAM:
|
||||
from segment_anything import build_sam_vit_h # type: ignore
|
||||
|
@ -98,11 +89,6 @@ def sam_h_single_output(sam_h_weights: Path, test_device: torch.device) -> Segme
|
|||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ref_path(test_sam_path: Path) -> Path:
|
||||
return test_sam_path / "test_sam_ref"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def truck(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "truck.jpg").convert("RGB")
|
||||
|
||||
|
@ -283,14 +269,14 @@ def test_mask_decoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> N
|
|||
|
||||
mapping = converter.map_state_dicts(source_args=inputs, target_args={})
|
||||
assert mapping is not None
|
||||
mapping["IOUMaskEncoder"] = "iou_token"
|
||||
mapping["MaskDecoderTokens.Parameter"] = "iou_token"
|
||||
|
||||
state_dict = converter._convert_state_dict( # type: ignore
|
||||
source_state_dict=facebook_mask_decoder.state_dict(),
|
||||
target_state_dict=refiners_mask_decoder.state_dict(),
|
||||
state_dict_mapping=mapping,
|
||||
)
|
||||
state_dict["IOUMaskEncoder.weight"] = torch.cat(
|
||||
state_dict["MaskDecoderTokens.Parameter.weight"] = torch.cat(
|
||||
[facebook_mask_decoder.iou_token.weight, facebook_mask_decoder.mask_tokens.weight], dim=0
|
||||
) # type: ignore
|
||||
refiners_mask_decoder.load_state_dict(state_dict=state_dict)
|
||||
|
@ -462,3 +448,26 @@ def test_mask_encoder(
|
|||
|
||||
assert facebook_mask_input.shape == mask_input.shape
|
||||
assert torch.allclose(dense_embeddings, fb_dense_embeddings, atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_batch_mask_decoder(sam_h: SegmentAnythingH) -> None:
|
||||
batch_size = 5
|
||||
|
||||
image_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1, 1)
|
||||
mask_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1, 1)
|
||||
dense_positional_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(
|
||||
batch_size, 1, 1, 1
|
||||
)
|
||||
point_embedding = torch.randn(1, 2, 256, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1)
|
||||
|
||||
sam_h.mask_decoder.set_image_embedding(image_embedding)
|
||||
sam_h.mask_decoder.set_mask_embedding(mask_embedding)
|
||||
sam_h.mask_decoder.set_point_embedding(point_embedding)
|
||||
sam_h.mask_decoder.set_dense_positional_embedding(dense_positional_embedding)
|
||||
|
||||
mask_prediction, iou_prediction = sam_h.mask_decoder()
|
||||
|
||||
assert mask_prediction.shape == (batch_size, 3, 256, 256)
|
||||
assert iou_prediction.shape == (batch_size, 3)
|
||||
assert torch.equal(mask_prediction[0], mask_prediction[1])
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
# Note about this data
|
||||
|
||||
`truck.jpg` is one of the [images](https://github.com/facebookresearch/segment-anything/tree/main/notebooks/images) used in the official [segment-anything notebooks](https://github.com/facebookresearch/segment-anything/tree/main/notebooks).
|
||||
|
||||
`tennis.png` is one of the [images](https://github.com/SysCV/sam-hq/tree/main/demo/input_imgs) used in the official [sam-hq demos](https://github.com/SysCV/sam-hq/tree/main/demo).
|
||||
|
|
BIN
tests/foundationals/segment_anything/test_sam_ref/tennis.png
Normal file
BIN
tests/foundationals/segment_anything/test_sam_ref/tennis.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.2 MiB |
|
@ -54,6 +54,23 @@ class FacebookSAMPredictor:
|
|||
) -> tuple[NDArray, NDArray, NDArray]: ...
|
||||
|
||||
|
||||
class FacebookSAMPredictorHQ:
|
||||
model: FacebookSAM
|
||||
|
||||
def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ...
|
||||
|
||||
def predict(
|
||||
self,
|
||||
point_coords: NDArray | None = None,
|
||||
point_labels: NDArray | None = None,
|
||||
box: NDArray | None = None,
|
||||
mask_input: NDArray | None = None,
|
||||
multimask_output: bool = True,
|
||||
return_logits: bool = False,
|
||||
hq_token_only: bool = False,
|
||||
) -> tuple[NDArray, NDArray, NDArray]: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class SAMPrompt:
|
||||
foreground_points: Sequence[tuple[float, float]] | None = None
|
||||
|
|
Loading…
Reference in a new issue