Add HQ-SAM Adapter

hugojarkoff 2024-03-21 13:59:36 +00:00 committed by hugojarkoff
parent c6b5eb24a1
commit a93ceff752
15 changed files with 939 additions and 71 deletions

View file

@ -44,6 +44,11 @@ test = [
# An unofficial Python package for Meta AI's Segment Anything Model:
# https://github.com/opengeos/segment-anything
"segment-anything-py>=1.0",
# Official Python package for HQ-SAM
"segment-anything-hq>=0.3",
# HQ-SAM missing dependency:
# https://github.com/SysCV/sam-hq/pull/59
"timm>=0.5.0",
]
conversion = [
"diffusers>=0.26.1",
@ -140,3 +145,10 @@ exclude_also = [
[tool.typos.default]
extend-words = { adaptee = "adaptee" }
extend-ignore-identifiers-re = ["NDArray*"]
[tool.pytest.ini_options]
filterwarnings = [
"ignore::UserWarning:segment_anything_hq.modeling.tiny_vit_sam.*",
"ignore::DeprecationWarning:timm.models.layers.*",
"ignore::DeprecationWarning:timm.models.registry.*"
]

View file

@ -255,6 +255,7 @@ safetensors==0.4.2
# via transformers
scipy==1.12.0
# via bitsandbytes
segment-anything-hq==0.3
segment-anything-py==1.0
# via refiners
sentry-sdk==1.40.6
@ -270,6 +271,7 @@ smmap==5.0.1
# via gitdb
sympy==1.12
# via torch
timm==0.9.16
tokenizers==0.15.2
# via transformers
tomli==2.0.1

View file

@ -0,0 +1,81 @@
import argparse
from torch import Tensor
from refiners.fluxion.utils import load_tensors, save_to_safetensors
def main() -> None:
parser = argparse.ArgumentParser(description="Convert HQ SAM model to Refiners state_dict format")
parser.add_argument(
"--from",
type=str,
dest="source_path",
required=True,
default="sam_hq_vit_h.pth",
help="Path to the source model checkpoint.",
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
required=True,
default="refiners_sam_hq_vit_h.safetensors",
help="Path to save the converted model in Refiners format.",
)
args = parser.parse_args()
source_state_dict = load_tensors(args.source_path)
state_dict: dict[str, Tensor] = {}
for suffix in ["weight", "bias"]:
state_dict[f"HQFeatures.CompressViTFeat.ConvTranspose2d_1.{suffix}"] = source_state_dict[
f"mask_decoder.compress_vit_feat.0.{suffix}"
]
state_dict[f"HQFeatures.EmbeddingEncoder.ConvTranspose2d_1.{suffix}"] = source_state_dict[
f"mask_decoder.embedding_encoder.0.{suffix}"
]
state_dict[f"EmbeddingMaskfeature.Conv2d_1.{suffix}"] = source_state_dict[
f"mask_decoder.embedding_maskfeature.0.{suffix}"
]
state_dict[f"HQFeatures.CompressViTFeat.LayerNorm2d.{suffix}"] = source_state_dict[
f"mask_decoder.compress_vit_feat.1.{suffix}"
]
state_dict[f"HQFeatures.EmbeddingEncoder.LayerNorm2d.{suffix}"] = source_state_dict[
f"mask_decoder.embedding_encoder.1.{suffix}"
]
state_dict[f"EmbeddingMaskfeature.LayerNorm2d.{suffix}"] = source_state_dict[
f"mask_decoder.embedding_maskfeature.1.{suffix}"
]
state_dict[f"HQFeatures.CompressViTFeat.ConvTranspose2d_2.{suffix}"] = source_state_dict[
f"mask_decoder.compress_vit_feat.3.{suffix}"
]
state_dict[f"HQFeatures.EmbeddingEncoder.ConvTranspose2d_2.{suffix}"] = source_state_dict[
f"mask_decoder.embedding_encoder.3.{suffix}"
]
state_dict[f"EmbeddingMaskfeature.Conv2d_2.{suffix}"] = source_state_dict[
f"mask_decoder.embedding_maskfeature.3.{suffix}"
]
state_dict = {f"Chain.HQSAMMaskPrediction.Chain.DenseEmbeddingUpscalingHQ.{k}": v for k, v in state_dict.items()}
# HQ Token
state_dict["MaskDecoderTokensExtender.hq_token.weight"] = source_state_dict["mask_decoder.hf_token.weight"]
# HQ MLP
for i in range(3):
state_dict[f"Chain.HQSAMMaskPrediction.HQTokenMLP.MultiLinear.Linear_{i+1}.weight"] = source_state_dict[
f"mask_decoder.hf_mlp.layers.{i}.weight"
]
state_dict[f"Chain.HQSAMMaskPrediction.HQTokenMLP.MultiLinear.Linear_{i+1}.bias"] = source_state_dict[
f"mask_decoder.hf_mlp.layers.{i}.bias"
]
save_to_safetensors(path=args.output_path, tensors=state_dict)
if __name__ == "__main__":
main()
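
A quick sanity check on the converted file can help confirm the key mapping above. This is a sketch, not part of the commit: the file name is the script's default `--to` value, the key names come from the mapping above, and the (1, 256) shape follows the HQ token definition (1 x embedding_dim).

from refiners.fluxion.utils import load_from_safetensors

# Sketch: verify a couple of converted keys and the HQ token shape (not part of the commit).
tensors = load_from_safetensors("refiners_sam_hq_vit_h.safetensors")
assert "MaskDecoderTokensExtender.hq_token.weight" in tensors
assert tensors["MaskDecoderTokensExtender.hq_token.weight"].shape == (1, 256)
assert "Chain.HQSAMMaskPrediction.HQTokenMLP.MultiLinear.Linear_1.weight" in tensors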

View file

@ -185,17 +185,16 @@ def convert_mask_decoder(mask_decoder: nn.Module) -> dict[str, Tensor]:
mapping = converter.map_state_dicts(source_args=inputs, target_args={})
assert mapping is not None
mapping["IOUMaskEncoder"] = "iou_token"
mapping["MaskDecoderTokens.Parameter"] = "iou_token"
state_dict = converter._convert_state_dict( # type: ignore
source_state_dict=mask_decoder.state_dict(),
target_state_dict=refiners_mask_decoder.state_dict(),
state_dict_mapping=mapping,
)
state_dict["IOUMaskEncoder.weight"] = torch.cat(
state_dict["MaskDecoderTokens.Parameter.weight"] = torch.cat(
tensors=[mask_decoder.iou_token.weight, mask_decoder.mask_tokens.weight], dim=0
) # type: ignore
refiners_mask_decoder.load_state_dict(state_dict=state_dict)
refiners_mask_decoder.set_image_embedding(image_embedding)
@ -254,10 +253,10 @@ def main() -> None:
mask_encoder_state_dict = convert_mask_encoder(prompt_encoder=sam_h.prompt_encoder)
output_state_dict = {
**{".".join(("image_encoder", key)): value for key, value in vit_state_dict.items()},
**{".".join(("mask_decoder", key)): value for key, value in mask_decoder_state_dict.items()},
**{".".join(("point_encoder", key)): value for key, value in point_encoder_state_dict.items()},
**{".".join(("mask_encoder", key)): value for key, value in mask_encoder_state_dict.items()},
**{f"SAMViTH.{key}": value for key, value in vit_state_dict.items()},
**{f"MaskDecoder.{key}": value for key, value in mask_decoder_state_dict.items()},
**{f"PointEncoder.{key}": value for key, value in point_encoder_state_dict.items()},
**{f"MaskEncoder.{key}": value for key, value in mask_encoder_state_dict.items()},
}
if args.half:
output_state_dict = {key: value.half() for key, value in output_state_dict.items()}

View file

@ -366,6 +366,13 @@ def download_sam():
)
def download_hq_sam():
weights_folder = os.path.join(test_weights_dir)
download_file(
"https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth", weights_folder, expected_hash="66da2472"
)
def download_dinov2():
# For conversion
weights_folder = os.path.join(test_weights_dir)
@ -661,7 +668,16 @@ def convert_sam():
"convert_segment_anything.py",
"tests/weights/sam_vit_h_4b8939.pth",
"tests/weights/segment-anything-h.safetensors",
expected_hash="b62ad5ed",
expected_hash="5ffb976f",
)
def convert_hq_sam():
run_conversion_script(
"convert_hq_segment_anything.py",
"tests/weights/sam_hq_vit_h.pth",
"tests/weights/refiners-sam-hq-vit-h.safetensors",
expected_hash="b2f5e79f",
)
@ -769,6 +785,7 @@ def download_all():
download_ip_adapter()
download_t2i_adapter()
download_sam()
download_hq_sam()
download_dinov2()
download_control_lora_fooocus()
download_lcm_base()
@ -789,6 +806,7 @@ def convert_all():
convert_ip_adapter()
convert_t2i_adapter()
convert_sam()
convert_hq_sam()
convert_dinov2()
convert_control_lora_fooocus()
convert_lcm_base()

View file

@ -0,0 +1,378 @@
import torch
from torch import device as Device, dtype as DType
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters import Adapter
from refiners.fluxion.context import Contexts
from refiners.foundationals.segment_anything.image_encoder import SAMViT, TransformerLayer
from refiners.foundationals.segment_anything.mask_decoder import (
MaskDecoderTokens,
MaskPrediction,
Predictions,
)
from refiners.foundationals.segment_anything.model import SegmentAnything
class CompressViTFeat(fl.Chain):
def __init__(
self,
transformer_dim: int = 256,
vit_dim: int = 1024,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
fl.UseContext(context="hq_sam", key="early_vit_embedding"),
fl.Permute(0, 3, 1, 2),
fl.ConvTranspose2d(
in_channels=vit_dim,
out_channels=transformer_dim,
kernel_size=2,
stride=2,
device=device,
dtype=dtype,
),
fl.LayerNorm2d(channels=transformer_dim, device=device, dtype=dtype),
fl.GeLU(),
fl.ConvTranspose2d(
in_channels=transformer_dim,
out_channels=transformer_dim // 8,
kernel_size=2,
stride=2,
device=device,
dtype=dtype,
),
)
class EmbeddingEncoder(fl.Chain):
def __init__(
self,
transformer_dim: int = 256,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
fl.UseContext(context="mask_decoder", key="image_embedding"),
fl.ConvTranspose2d(
in_channels=transformer_dim,
out_channels=transformer_dim // 4,
kernel_size=2,
stride=2,
device=device,
dtype=dtype,
),
fl.LayerNorm2d(channels=transformer_dim // 4, device=device, dtype=dtype),
fl.GeLU(),
fl.ConvTranspose2d(
in_channels=transformer_dim // 4,
out_channels=transformer_dim // 8,
kernel_size=2,
stride=2,
device=device,
dtype=dtype,
),
)
class HQFeatures(fl.Sum):
def __init__(
self,
vit_dim: int = 1024,
transformer_dim: int = 256,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
EmbeddingEncoder(transformer_dim, device, dtype),
CompressViTFeat(transformer_dim, vit_dim, device, dtype),
)
class EmbeddingMaskfeature(fl.Chain):
def __init__(
self,
transformer_dim: int = 256,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
fl.UseContext(context="mask_decoder", key="upscaled_dense_embedding"),
fl.Reshape(-1, transformer_dim, transformer_dim),
fl.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1, device=device, dtype=dtype),
fl.LayerNorm2d(transformer_dim // 4, device=device, dtype=dtype),
fl.GeLU(),
fl.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1, device=device, dtype=dtype),
)
class DenseEmbeddingUpscalingHQ(fl.Sum):
def __init__(
self,
vit_dim: int = 1024,
transformer_dim: int = 256,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
EmbeddingMaskfeature(transformer_dim, device, dtype),
HQFeatures(vit_dim, transformer_dim, device, dtype),
)
class HQTokenMLP(fl.Chain):
def __init__(
self,
embedding_dim: int,
num_layers: int = 3,
target_num_mask_tokens: int = 5,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
fl.Slicing(dim=1, start=target_num_mask_tokens, end=target_num_mask_tokens + 1), # HQ token
fl.MultiLinear(
input_dim=embedding_dim,
output_dim=embedding_dim // 8,
inner_dim=embedding_dim,
num_layers=num_layers,
device=device,
dtype=dtype,
),
)
class HQSAMMaskPrediction(fl.Matmul):
def __init__(
self,
embedding_dim: int,
vit_dim: int = 1024,
target_num_mask_tokens: int = 5,
num_layers: int = 3,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
HQTokenMLP(
embedding_dim,
num_layers=num_layers,
target_num_mask_tokens=target_num_mask_tokens,
device=device,
dtype=dtype,
),
fl.Chain(
DenseEmbeddingUpscalingHQ(vit_dim=vit_dim, transformer_dim=256, device=device, dtype=dtype),
fl.Flatten(start_dim=2),
),
)
class MaskPredictionAdapter(fl.Concatenate, Adapter[MaskPrediction]):
def __init__(
self,
target: MaskPrediction,
vit_dim: int = 1024,
target_num_mask_tokens: int = 5,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
with self.setup_adapter(target):
super().__init__(
target,
fl.Chain(
HQSAMMaskPrediction(
embedding_dim=target.embedding_dim,
vit_dim=vit_dim,
target_num_mask_tokens=target_num_mask_tokens,
num_layers=3,
device=device,
dtype=dtype,
),
fl.Reshape(-1, target.embedding_dim, target.embedding_dim),
),
dim=1,
)
@property
def hq_sam_mask_prediction(self) -> HQSAMMaskPrediction:
return self.ensure_find(HQSAMMaskPrediction)
class MaskDecoderTokensExtender(fl.Concatenate, Adapter[MaskDecoderTokens]):
"""
Add a new weight to the MaskDecoderTokens to store the new HQ token.
"""
def __init__(
self,
target: MaskDecoderTokens,
) -> None:
self._hq_token = [fl.Parameter(1, target.embedding_dim, device=target.device, dtype=target.dtype)]
with self.setup_adapter(target):
super().__init__(
target,
fl.Chain(
fl.UseContext(context="mask_decoder", key="image_embedding"), # use Context to infer batch size
self.hq_token,
),
dim=1,
)
@property
def regular_tokens(self) -> fl.Parameter:
return self.target.ensure_find(fl.Parameter)
@property
def hq_token(self) -> fl.Parameter:
return self._hq_token[0]
class SAMViTAdapter(fl.Chain, Adapter[SAMViT]):
"""
Add a context to the image encoder transformer to store its early ViT embedding
(first intermediate embedding of the ViT).
"""
def __init__(self, target: SAMViT) -> None:
with self.setup_adapter(target):
super().__init__(target)
target_transformer_layer = self._find_target_transformer_layer()
assert target_transformer_layer is not None
self._transformer_layer = [target_transformer_layer]
self._set_early_vit_embedding_context = [fl.SetContext("hq_sam", "early_vit_embedding")]
@property
def target_transformer_layer(self) -> TransformerLayer:
return self._transformer_layer[0]
@property
def set_early_vit_embedding_context(self) -> fl.SetContext:
return self._set_early_vit_embedding_context[0]
def _find_target_transformer_layer(self) -> TransformerLayer | None:
for transformer_layer in self.target.layers(TransformerLayer):
if transformer_layer.window_size is None:
return transformer_layer
return None
def inject(self: "SAMViTAdapter", parent: fl.Chain | None = None) -> "SAMViTAdapter":
self.target_transformer_layer.append(self.set_early_vit_embedding_context)
return super().inject(parent)
def eject(self) -> None:
self.target_transformer_layer.remove(self.set_early_vit_embedding_context)
super().eject()
class PredictionsPostProc(fl.Module):
def __init__(self, hq_mask_only: bool = False) -> None:
super().__init__()
self.hq_mask_only = hq_mask_only
def forward(
self, masks_predictions: torch.Tensor, iou_predictions: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
hq_sam_mask = masks_predictions[:, -1:, ...]
# The official implementation of HQ-SAM has two output modes:
# 1. HQ mask only
# 2. HQ mask + base SAM mask, using HQ as a correction to the base SAM mask
# Details can be found in the paper: https://arxiv.org/abs/2306.01567 (section 3.3)
# Heuristics are provided by the authors here: https://github.com/SysCV/sam-hq/blob/3224888/demo/demo_hqsam_pip_example.py#L73-L75
if self.hq_mask_only:
return (hq_sam_mask, iou_predictions)
base_sam_masks = masks_predictions[:, :-1, ...]
assert base_sam_masks.shape[1] == 1
return (hq_sam_mask + base_sam_masks, iou_predictions)
class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]):
"""Adapter for SAM introducing HQ features.
See [[arXiv:2306.01567] Segment Anything in High Quality](https://arxiv.org/abs/2306.01567) for details.
"""
def init_context(self) -> Contexts:
return {"hq_sam": {"early_vit_embedding": None}}
def __init__(
self,
target: SegmentAnything,
hq_mask_only: bool = False,
weights: dict[str, torch.Tensor] | None = None,
) -> None:
self.vit_embedding_dim = target.image_encoder.embedding_dim
self.target_num_mask_tokens = target.mask_decoder.num_multimask_outputs + 2
with self.setup_adapter(target):
super().__init__(target)
if target.mask_decoder.multimask_output:
raise NotImplementedError("Multi-mask mode is not supported in HQSAMAdapter.")
mask_prediction = target.mask_decoder.ensure_find(MaskPrediction)
self._mask_prediction_adapter = [
MaskPredictionAdapter(
mask_prediction, self.vit_embedding_dim, self.target_num_mask_tokens, target.device, target.dtype
)
]
self._image_encoder_adapter = [SAMViTAdapter(target.image_encoder)]
self._predictions_post_proc = [PredictionsPostProc(hq_mask_only)]
mask_decoder_tokens = target.mask_decoder.ensure_find(MaskDecoderTokens)
self._mask_decoder_tokens_extender = [MaskDecoderTokensExtender(mask_decoder_tokens)]
if weights is not None:
hq_token_prefix = "MaskDecoderTokensExtender.hq_token."
hq_token_state_dict: dict[str, torch.Tensor] = {
k.removeprefix(hq_token_prefix): v for k, v in weights.items() if k.startswith(hq_token_prefix)
}
self.mask_decoder_tokens_extender.hq_token.load_state_dict(hq_token_state_dict)
mask_pred_prefix = "Chain.HQSAMMaskPrediction."
mask_pred_state_dict: dict[str, torch.Tensor] = {
k.removeprefix(mask_pred_prefix): v for k, v in weights.items() if k.startswith(mask_pred_prefix)
}
self.mask_prediction_adapter.hq_sam_mask_prediction.load_state_dict(mask_pred_state_dict)
self.to(device=target.device, dtype=target.dtype)
@property
def mask_decoder_tokens_extender(self) -> MaskDecoderTokensExtender:
return self._mask_decoder_tokens_extender[0]
@property
def mask_prediction_adapter(self) -> MaskPredictionAdapter:
return self._mask_prediction_adapter[0]
@property
def image_encoder_adapter(self) -> SAMViTAdapter:
return self._image_encoder_adapter[0]
@property
def predictions_post_proc(self) -> PredictionsPostProc:
return self._predictions_post_proc[0]
@property
def hq_mask_only(self) -> bool:
return self.predictions_post_proc.hq_mask_only
@hq_mask_only.setter
def hq_mask_only(self, value: bool) -> None:
self.predictions_post_proc.hq_mask_only = value
def inject(self: "HQSAMAdapter", parent: fl.Chain | None = None) -> "HQSAMAdapter":
self.mask_decoder_tokens_extender.inject()
self.mask_prediction_adapter.inject()
self.image_encoder_adapter.inject()
self.target.mask_decoder.insert_after_type(Predictions, self.predictions_post_proc)
return super().inject(parent)
def eject(self) -> None:
self.mask_decoder_tokens_extender.eject()
self.mask_prediction_adapter.eject()
self.image_encoder_adapter.eject()
self.target.mask_decoder.remove(self.predictions_post_proc)
super().eject()
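
For orientation, here is a usage sketch of the adapter. It is not part of the commit: it assumes base SAM-H weights and HQ adapter weights converted by the script above, and the file names and box prompt mirror the test fixtures further down.

from PIL import Image
from refiners.fluxion.utils import load_from_safetensors, no_grad
from refiners.foundationals.segment_anything.hq_sam import HQSAMAdapter
from refiners.foundationals.segment_anything.model import SegmentAnythingH

# Sketch: build a single-mask SAM-H, inject the HQ adapter, then predict with a box prompt.
sam_h = SegmentAnythingH(multimask_output=False)  # HQSAMAdapter requires single-mask output
sam_h.load_from_safetensors(tensors_path="segment-anything-h.safetensors")

adapter = HQSAMAdapter(
    sam_h,
    hq_mask_only=False,  # False: the HQ mask is added to the base SAM mask (see PredictionsPostProc)
    weights=load_from_safetensors("refiners-sam-hq-vit-h.safetensors"),
)
adapter.inject()

image = Image.open("tennis.png").convert("RGB")
with no_grad():
    high_res_masks, iou_predictions, low_res_masks = sam_h.predict(
        image, box_points=[[(4, 13), (1007, 1023)]]
    )
# high_res_masks: (1, 1, 1024, 1024), iou_predictions: (1, 1), low_res_masks: (1, 1, 256, 256)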

View file

@ -1,5 +1,5 @@
import torch
from torch import Tensor, device as Device, dtype as DType, nn
from torch import Tensor, device as Device, dtype as DType
import refiners.fluxion.layers as fl
from refiners.fluxion.context import Contexts
@ -10,7 +10,7 @@ from refiners.foundationals.segment_anything.transformer import (
class EmbeddingsAggregator(fl.ContextModule):
def forward(self, iou_mask_tokens: Tensor) -> Tensor:
def forward(self, tokens: Tensor) -> Tensor:
mask_decoder = self.ensure_parent
mask_decoder_context = mask_decoder.use_context(context_name="mask_decoder")
image_embedding = mask_decoder_context["image_embedding"]
@ -18,7 +18,7 @@ class EmbeddingsAggregator(fl.ContextModule):
mask_embedding = mask_decoder_context["mask_embedding"]
dense_positional_embedding = mask_decoder_context["dense_positional_embedding"]
sparse_embedding = torch.cat(tensors=(iou_mask_tokens, point_embedding), dim=1)
sparse_embedding = torch.cat(tensors=(tokens, point_embedding), dim=1)
dense_embedding = (image_embedding + mask_embedding).flatten(start_dim=2).transpose(1, 2)
if dense_positional_embedding.shape != dense_embedding.shape:
dense_positional_embedding = dense_positional_embedding.flatten(start_dim=2).transpose(1, 2)
@ -108,10 +108,11 @@ class DenseEmbeddingUpscaling(fl.Chain):
),
fl.GeLU(),
fl.Flatten(start_dim=2),
fl.SetContext(context="mask_decoder", key="upscaled_dense_embedding"),
)
class IOUMaskEncoder(fl.WeightedModule):
class MaskDecoderTokens(fl.Chain):
def __init__(
self,
embedding_dim: int = 256,
@ -119,14 +120,13 @@ class IOUMaskEncoder(fl.WeightedModule):
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.num_mask_tokens = num_mask_tokens
# aka prompt tokens + output token (for IoU scores prediction)
self.weight = nn.Parameter(data=torch.randn(num_mask_tokens + 1, embedding_dim, device=device, dtype=dtype))
def forward(self) -> Tensor:
return self.weight.unsqueeze(dim=0)
# aka output tokens (single-mask output + multi-mask output) + IoU token
super().__init__(
fl.UseContext(context="mask_decoder", key="image_embedding"), # use Context to infer batch size
fl.Parameter(num_mask_tokens + 1, embedding_dim, device=device, dtype=dtype),
)
class MaskPrediction(fl.Chain):
@ -178,7 +178,6 @@ class IOUPrediction(fl.Chain):
super().__init__(
fl.Slicing(dim=1, start=0, end=1),
fl.Squeeze(dim=0),
fl.MultiLinear(
input_dim=embedding_dim,
output_dim=num_mask_tokens,
@ -188,6 +187,39 @@ class IOUPrediction(fl.Chain):
dtype=dtype,
),
fl.Slicing(dim=-1, start=1) if multimask_output else fl.Slicing(dim=-1, start=0, end=1),
fl.Squeeze(dim=1),
)
class Predictions(fl.Parallel):
def __init__(
self,
embedding_dim: int,
num_mask_tokens: int,
multimask_output: bool,
num_layers: int = 3,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.num_mask_tokens = num_mask_tokens
self.num_layers = num_layers
super().__init__(
MaskPrediction(
embedding_dim=embedding_dim,
num_mask_tokens=num_mask_tokens,
multimask_output=multimask_output,
device=device,
dtype=dtype,
),
IOUPrediction(
embedding_dim=embedding_dim,
num_layers=num_layers,
num_mask_tokens=num_mask_tokens,
multimask_output=multimask_output,
device=device,
dtype=dtype,
),
)
@ -213,7 +245,7 @@ class MaskDecoder(fl.Chain):
num_mask_tokens = self.num_multimask_outputs + 1
super().__init__(
IOUMaskEncoder(embedding_dim=embedding_dim, num_mask_tokens=num_mask_tokens, device=device, dtype=dtype),
MaskDecoderTokens(embedding_dim=embedding_dim, num_mask_tokens=num_mask_tokens, device=device, dtype=dtype),
EmbeddingsAggregator(),
Transformer(
*(
@ -230,22 +262,12 @@ class MaskDecoder(fl.Chain):
SparseCrossDenseAttention(embedding_dim=embedding_dim, device=device, dtype=dtype),
fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
),
fl.Parallel(
MaskPrediction(
embedding_dim=embedding_dim,
num_mask_tokens=num_mask_tokens,
multimask_output=multimask_output,
device=device,
dtype=dtype,
),
IOUPrediction(
embedding_dim=embedding_dim,
num_layers=3,
num_mask_tokens=num_mask_tokens,
multimask_output=multimask_output,
device=device,
dtype=dtype,
),
Predictions(
embedding_dim=embedding_dim,
num_mask_tokens=num_mask_tokens,
multimask_output=multimask_output,
device=device,
dtype=dtype,
),
)

View file

@ -20,7 +20,7 @@ class ImageEmbedding:
original_image_size: tuple[int, int] # (height, width)
class SegmentAnything(fl.Module):
class SegmentAnything(fl.Chain):
"""SegmentAnything model.
See [[arXiv:2304.02643] Segment Anything](https://arxiv.org/abs/2304.02643)
@ -47,16 +47,30 @@ class SegmentAnything(fl.Module):
point_encoder: The point encoder to use.
mask_encoder: The mask encoder to use.
mask_decoder: The mask decoder to use.
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
super().__init__()
self.device: Device = device if isinstance(device, Device) else Device(device=device)
self.dtype = dtype
self.image_encoder = image_encoder.to(device=self.device, dtype=self.dtype)
self.point_encoder = point_encoder.to(device=self.device, dtype=self.dtype)
self.mask_encoder = mask_encoder.to(device=self.device, dtype=self.dtype)
self.mask_decoder = mask_decoder.to(device=self.device, dtype=self.dtype)
super().__init__(image_encoder, point_encoder, mask_encoder, mask_decoder)
self.to(device=device, dtype=dtype)
@property
def image_encoder(self) -> SAMViT:
"""The image encoder."""
return self.ensure_find(SAMViT)
@property
def point_encoder(self) -> PointEncoder:
"""The point encoder."""
return self.ensure_find(PointEncoder)
@property
def mask_encoder(self) -> MaskEncoder:
"""The mask encoder."""
return self.ensure_find(MaskEncoder)
@property
def mask_decoder(self) -> MaskDecoder:
"""The mask decoder."""
return self.ensure_find(MaskDecoder)
@no_grad()
def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
@ -259,11 +273,11 @@ class SegmentAnythingH(SegmentAnything):
else:
mask_decoder = MaskDecoder(multimask_output) if multimask_output is not None else MaskDecoder()
super().__init__(
image_encoder=image_encoder,
point_encoder=point_encoder,
mask_encoder=mask_encoder,
mask_decoder=mask_decoder,
device=device,
dtype=dtype,
)
super().__init__(image_encoder, point_encoder, mask_encoder, mask_decoder)
self.to(device=device, dtype=dtype)
@property
def image_encoder(self) -> SAMViTH:
"""The image encoder."""
return self.ensure_find(SAMViTH)

View file

@ -98,7 +98,7 @@ class PointEncoder(fl.Chain):
) -> Float[Tensor, "num_positional_features height width"]:
coordinate_encoder = self.ensure_find(layer_type=CoordinateEncoder)
height, width = image_embedding_size
grid = torch.ones((height, width), device=self.device, dtype=torch.float32)
grid = torch.ones((height, width), device=self.device, dtype=self.dtype)
y_embedding = grid.cumsum(dim=0) - 0.5
x_embedding = grid.cumsum(dim=1) - 0.5
y_embedding = y_embedding / height

View file

@ -0,0 +1,18 @@
from pathlib import Path
from warnings import warn
from pytest import fixture, skip
@fixture(scope="package")
def ref_path(test_sam_path: Path) -> Path:
return test_sam_path / "test_sam_ref"
@fixture(scope="package")
def sam_h_weights(test_weights_path: Path) -> Path:
sam_h_weights = test_weights_path / "segment-anything-h.safetensors"
if not sam_h_weights.is_file():
warn(f"could not find weights at {sam_h_weights}, skipping")
skip(allow_module_level=True)
return sam_h_weights

View file

@ -0,0 +1,296 @@
from pathlib import Path
from typing import cast
from warnings import warn
import numpy as np
import pytest
import torch
from PIL import Image
from segment_anything_hq import ( # type: ignore
SamPredictor as SamPredictorHQ,
sam_model_registry as sam_model_registry_hq, # type: ignore
)
from segment_anything_hq.modeling.sam import Sam # type: ignore
from tests.foundationals.segment_anything.utils import FacebookSAM, FacebookSAMPredictorHQ, SAMPrompt
from torch import optim
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, no_grad
from refiners.foundationals.segment_anything.hq_sam import (
CompressViTFeat,
EmbeddingEncoder,
HQSAMAdapter,
HQTokenMLP,
MaskDecoderTokensExtender,
PredictionsPostProc,
)
from refiners.foundationals.segment_anything.model import SegmentAnythingH
@pytest.fixture(scope="module")
def one_prompt() -> SAMPrompt:
return SAMPrompt(box_points=[[(4, 13), (1007, 1023)]])
@pytest.fixture(scope="module")
def tennis(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "tennis.png").convert("RGB")
@pytest.fixture(scope="module")
def hq_adapter_weights(test_weights_path: Path) -> Path:
"""Path to the HQ adapter weights in Refiners format"""
refiners_hq_adapter_sam_weights = test_weights_path / "refiners-sam-hq-vit-h.safetensors"
if not refiners_hq_adapter_sam_weights.is_file():
warn(f"Test weights not found at {refiners_hq_adapter_sam_weights}, skipping")
pytest.skip(allow_module_level=True)
return refiners_hq_adapter_sam_weights
@pytest.fixture
def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
# HQSAMAdapter is designed to be used with single-output only, hence multimask_output=False.
sam_h = SegmentAnythingH(multimask_output=False, device=test_device)
sam_h.load_from_safetensors(tensors_path=sam_h_weights)
return sam_h
@pytest.fixture(scope="module")
def reference_hq_adapter_weights(test_weights_path: Path) -> Path:
"""Path to the HQ adapter weights in default format"""
reference_hq_adapter_sam_weights = test_weights_path / "sam_hq_vit_h.pth"
if not reference_hq_adapter_sam_weights.is_file():
warn(f"Test weights not found at {reference_hq_adapter_sam_weights}, skipping")
pytest.skip(allow_module_level=True)
return reference_hq_adapter_sam_weights
@pytest.fixture(scope="module")
def reference_sam_h(reference_hq_adapter_weights: Path, test_device: torch.device) -> FacebookSAM:
sam_h = cast(FacebookSAM, sam_model_registry_hq["vit_h"](checkpoint=reference_hq_adapter_weights))
return sam_h.to(device=test_device)
@pytest.fixture(scope="module")
def reference_sam_h_predictor(reference_sam_h: FacebookSAM) -> FacebookSAMPredictorHQ:
predictor = SamPredictorHQ(cast(Sam, reference_sam_h))
return cast(FacebookSAMPredictorHQ, predictor)
def test_inject_eject() -> None:
sam_h = SegmentAnythingH(multimask_output=False)
initial_repr = repr(sam_h)
adapter = HQSAMAdapter(sam_h)
assert repr(sam_h) == initial_repr
adapter.inject()
assert repr(sam_h) != initial_repr
adapter.eject()
assert repr(sam_h) == initial_repr
def test_multimask_forbidden() -> None:
with pytest.raises(NotImplementedError, match="not supported"):
HQSAMAdapter(target=SegmentAnythingH(multimask_output=True))
def test_output_shape_hq_adapter(tennis: Image.Image, one_prompt: SAMPrompt) -> None:
sam_h = SegmentAnythingH(multimask_output=False)
HQSAMAdapter(sam_h).inject()
high_res_masks, iou_predictions, low_res_masks = sam_h.predict(tennis, **one_prompt.__dict__)
assert high_res_masks.shape == (1, 1, 1024, 1024)
assert iou_predictions.shape == (1, 1)
assert low_res_masks.shape == (1, 1, 256, 256)
def test_mask_decoder_tokens_extender() -> None:
sam_h = SegmentAnythingH(multimask_output=False)
sam_h.requires_grad_(False)
# MaskDecoderTokens requires image_embedding context to be set
image_embedding = torch.randn(2, 256, 64, 64)
sam_h.mask_decoder.set_image_embedding(image_embedding)
HQSAMAdapter(sam_h).inject()
mask_decoder_tokens = sam_h.ensure_find(MaskDecoderTokensExtender)
tokens_before = mask_decoder_tokens()
assert tokens_before.shape == torch.Size([2, 6, 256])
for p in mask_decoder_tokens.parameters():
match p.shape:
case torch.Size([5, 256]):
assert not p.requires_grad
case torch.Size([1, 256]):
assert p.requires_grad
case _:
raise ValueError
optimizer = optim.SGD(mask_decoder_tokens.parameters(), lr=10)
optimizer.zero_grad()
ones = torch.ones_like(tokens_before)
loss = torch.nn.functional.mse_loss(tokens_before, ones)
loss.backward() # type: ignore
optimizer.step()
tokens_after = mask_decoder_tokens()
assert torch.equal(tokens_before[:, :5, :], tokens_after[:, :5, :])
assert not torch.equal(tokens_before[:, 5, :], tokens_after[:, 5, :])
@no_grad()
def test_early_vit_embedding(
sam_h: SegmentAnythingH,
hq_adapter_weights: Path,
reference_sam_h: FacebookSAM,
tennis: Image.Image,
) -> None:
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
image_tensor = image_to_tensor(image=tennis.resize(size=(1024, 1024)))
_ = sam_h.image_encoder(image_tensor.to(sam_h.device))
early_vit_embedding_refiners = sam_h.use_context(context_name="hq_sam")["early_vit_embedding"]
_, intermediate_embeddings = reference_sam_h.image_encoder(image_tensor.to(reference_sam_h.device))
early_vit_embedding = intermediate_embeddings[0]
assert torch.equal(early_vit_embedding, early_vit_embedding_refiners)
def test_tokens(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
mask_decoder_tokens_extender = sam_h.mask_decoder.ensure_find(MaskDecoderTokensExtender)
# HF Token (1, 256)
assert torch.equal(reference_sam_h.mask_decoder.hf_token.weight, mask_decoder_tokens_extender.hq_token.weight)
# Regular Tokens (5, 256)
assert torch.equal(
torch.cat([reference_sam_h.mask_decoder.iou_token.weight, reference_sam_h.mask_decoder.mask_tokens.weight]),
mask_decoder_tokens_extender.regular_tokens.weight,
)
@no_grad()
def test_compress_vit_feat(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
early_vit_embedding = torch.randn(1, 64, 64, 1280, device=sam_h.device, dtype=sam_h.dtype)
sam_h.set_context(context="hq_sam", value={"early_vit_embedding": early_vit_embedding})
refiners_output = sam_h.ensure_find(CompressViTFeat)()
reference_output = reference_sam_h.mask_decoder.compress_vit_feat(early_vit_embedding.permute(0, 3, 1, 2))
assert torch.equal(refiners_output, reference_output)
@no_grad()
def test_embedding_encoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
x = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype)
sam_h.set_context(context="mask_decoder", value={"image_embedding": x})
refiners_output = sam_h.ensure_find(EmbeddingEncoder)()
reference_output = reference_sam_h.mask_decoder.embedding_encoder(x)
assert torch.equal(refiners_output, reference_output)
@no_grad()
def test_hq_token_mlp(sam_h: SegmentAnythingH, hq_adapter_weights: Path, reference_sam_h: FacebookSAM) -> None:
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
x = torch.randn(1, 6, 256, device=sam_h.device, dtype=sam_h.dtype)
refiners_output = sam_h.ensure_find(HQTokenMLP)(x)
reference_output = reference_sam_h.mask_decoder.hf_mlp(x[:, -1, :]).unsqueeze(0)
assert torch.equal(refiners_output, reference_output)
@pytest.mark.parametrize("hq_mask_only", [True, False])
def test_predictor(
sam_h: SegmentAnythingH,
hq_adapter_weights: Path,
hq_mask_only: bool,
reference_sam_h_predictor: FacebookSAMPredictorHQ,
tennis: Image.Image,
one_prompt: SAMPrompt,
) -> None:
adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
adapter.hq_mask_only = hq_mask_only
assert sam_h.ensure_find(PredictionsPostProc).hq_mask_only == hq_mask_only
# Refiners
high_res_masks, iou_predictions, low_res_masks = sam_h.predict(tennis, **one_prompt.__dict__)
refiners_high_res_mask_hq = high_res_masks[0, 0, ...].to(dtype=torch.float32).detach().cpu()
refiners_low_res_mask_hq = low_res_masks[0, 0, ...].to(dtype=torch.float32).detach().cpu()
iou_predictions = iou_predictions[0, :].to(dtype=torch.float32).detach().cpu()
# Reference
reference_sam_h_predictor.set_image(np.array(tennis))
predictor_prompt = one_prompt.__dict__["box_points"]
masks_np, iou_predictions_np, low_res_masks_np = reference_sam_h_predictor.predict(
box=np.array(predictor_prompt).flatten(),
multimask_output=False,
hq_token_only=hq_mask_only,
)
reference_high_res_mask_hq = torch.from_numpy(masks_np[0, ...]).to(dtype=torch.float32) # type: ignore
reference_low_res_mask_hq = torch.from_numpy(low_res_masks_np[0, ...]).to(dtype=torch.float32) # type: ignore
iou_predictions_np = torch.from_numpy(iou_predictions_np).to(dtype=torch.float32) # type: ignore
# NOTE: The diff on the logits is relatively high, but on the same scale as (or even lower than) the base SAM logits diff (6e-3)
# See https://github.com/finegrain-ai/refiners/blob/c6b5eb24a179d48e4542d94684a70c5ef3142ab1/tests/foundationals/segment_anything/test_sam.py#L426
assert torch.allclose(
reference_low_res_mask_hq,
refiners_low_res_mask_hq,
atol=4e-3,
)
assert (
torch.abs(reference_high_res_mask_hq - refiners_high_res_mask_hq).flatten().sum() <= 1
) # The diff on the logits above leads to an absolute diff of at most 1 pixel on the high-res masks
assert torch.allclose(
iou_predictions_np,
torch.max(iou_predictions),
atol=1e-5,
)
@no_grad()
def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) -> None:
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
batch_size = 5
image_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1, 1)
mask_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1, 1)
dense_positional_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(
batch_size, 1, 1, 1
)
point_embedding = torch.randn(1, 2, 256, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1)
early_vit_embedding = torch.randn(1, 64, 64, 1280, device=sam_h.device, dtype=sam_h.dtype).repeat(
batch_size, 1, 1, 1
)
sam_h.mask_decoder.set_image_embedding(image_embedding)
sam_h.mask_decoder.set_mask_embedding(mask_embedding)
sam_h.mask_decoder.set_point_embedding(point_embedding)
sam_h.mask_decoder.set_dense_positional_embedding(dense_positional_embedding)
sam_h.mask_decoder.set_context(
context="hq_sam", value={"early_vit_embedding": early_vit_embedding.to(sam_h.device, sam_h.dtype)}
)
mask_prediction, iou_prediction = sam_h.mask_decoder()
assert mask_prediction.shape == (batch_size, 1, 256, 256)
assert iou_prediction.shape == (batch_size, 1)
assert torch.equal(mask_prediction[0], mask_prediction[1])

View file

@ -56,15 +56,6 @@ def facebook_sam_h_weights(test_weights_path: Path) -> Path:
return sam_h_weights
@pytest.fixture(scope="module")
def sam_h_weights(test_weights_path: Path) -> Path:
sam_h_weights = test_weights_path / "segment-anything-h.safetensors"
if not sam_h_weights.is_file():
warn(f"could not find weights at {sam_h_weights}, skipping")
pytest.skip(allow_module_level=True)
return sam_h_weights
@pytest.fixture(scope="module")
def facebook_sam_h(facebook_sam_h_weights: Path, test_device: torch.device) -> FacebookSAM:
from segment_anything import build_sam_vit_h # type: ignore
@ -98,11 +89,6 @@ def sam_h_single_output(sam_h_weights: Path, test_device: torch.device) -> Segme
@pytest.fixture(scope="module")
def ref_path(test_sam_path: Path) -> Path:
return test_sam_path / "test_sam_ref"
@pytest.fixture
def truck(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "truck.jpg").convert("RGB")
@ -283,14 +269,14 @@ def test_mask_decoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> N
mapping = converter.map_state_dicts(source_args=inputs, target_args={})
assert mapping is not None
mapping["IOUMaskEncoder"] = "iou_token"
mapping["MaskDecoderTokens.Parameter"] = "iou_token"
state_dict = converter._convert_state_dict( # type: ignore
source_state_dict=facebook_mask_decoder.state_dict(),
target_state_dict=refiners_mask_decoder.state_dict(),
state_dict_mapping=mapping,
)
state_dict["IOUMaskEncoder.weight"] = torch.cat(
state_dict["MaskDecoderTokens.Parameter.weight"] = torch.cat(
[facebook_mask_decoder.iou_token.weight, facebook_mask_decoder.mask_tokens.weight], dim=0
) # type: ignore
refiners_mask_decoder.load_state_dict(state_dict=state_dict)
@ -462,3 +448,26 @@ def test_mask_encoder(
assert facebook_mask_input.shape == mask_input.shape
assert torch.allclose(dense_embeddings, fb_dense_embeddings, atol=1e-4, rtol=1e-4)
@no_grad()
def test_batch_mask_decoder(sam_h: SegmentAnythingH) -> None:
batch_size = 5
image_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1, 1)
mask_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1, 1)
dense_positional_embedding = torch.randn(1, 256, 64, 64, device=sam_h.device, dtype=sam_h.dtype).repeat(
batch_size, 1, 1, 1
)
point_embedding = torch.randn(1, 2, 256, device=sam_h.device, dtype=sam_h.dtype).repeat(batch_size, 1, 1)
sam_h.mask_decoder.set_image_embedding(image_embedding)
sam_h.mask_decoder.set_mask_embedding(mask_embedding)
sam_h.mask_decoder.set_point_embedding(point_embedding)
sam_h.mask_decoder.set_dense_positional_embedding(dense_positional_embedding)
mask_prediction, iou_prediction = sam_h.mask_decoder()
assert mask_prediction.shape == (batch_size, 3, 256, 256)
assert iou_prediction.shape == (batch_size, 3)
assert torch.equal(mask_prediction[0], mask_prediction[1])

View file

@ -1,3 +1,5 @@
# Note about this data
`truck.jpg` is one of the [images](https://github.com/facebookresearch/segment-anything/tree/main/notebooks/images) used in the official [segment-anything notebooks](https://github.com/facebookresearch/segment-anything/tree/main/notebooks).
`tennis.png` is one of the [images](https://github.com/SysCV/sam-hq/tree/main/demo/input_imgs) used in the official [sam-hq demos](https://github.com/SysCV/sam-hq/tree/main/demo).

Binary file added: tennis.png (2.2 MiB), not shown.

View file

@ -54,6 +54,23 @@ class FacebookSAMPredictor:
) -> tuple[NDArray, NDArray, NDArray]: ...
class FacebookSAMPredictorHQ:
model: FacebookSAM
def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ...
def predict(
self,
point_coords: NDArray | None = None,
point_labels: NDArray | None = None,
box: NDArray | None = None,
mask_input: NDArray | None = None,
multimask_output: bool = True,
return_logits: bool = False,
hq_token_only: bool = False,
) -> tuple[NDArray, NDArray, NDArray]: ...
@dataclass
class SAMPrompt:
foreground_points: Sequence[tuple[float, float]] | None = None