add Segment Anything (SAM) to foundational models

Note: support for dense prompts (i.e. masks) is still partial (see MaskEncoder)

Co-authored-by: Cédric Deltheil <cedric@deltheil.me>
Benjamin Trom 2023-09-21 10:19:19 +02:00 committed by Cédric Deltheil
parent 2faff9f57a
commit 282578ddc0
14 changed files with 1843 additions and 11 deletions
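For reviewers, here is a minimal end-to-end sketch of the new API, assuming the weights have already been converted with the script included in this commit (file names are the script defaults, the prompt values come from the tests, and `strict=False` mirrors the test fixture since the MaskEncoder conversion is still partial):

```python
import torch
from PIL import Image

from refiners.foundationals.segment_anything.model import SegmentAnythingH

# Hypothetical paths: "segment-anything-h.safetensors" is the default output of the conversion script below.
sam_h = SegmentAnythingH(device="cuda" if torch.cuda.is_available() else "cpu")
sam_h.load_from_safetensors(tensors_path="segment-anything-h.safetensors", strict=False)

image = Image.open("truck.jpg").convert("RGB")

# Point prompts are given in original image coordinates; predict returns
# (high_res_masks, iou_predictions, low_res_masks).
high_res_masks, iou_predictions, low_res_masks = sam_h.predict(image, foreground_points=((500, 375),))

# The ViT features can also be computed once and reused across prompts:
embedding = sam_h.compute_image_embedding(image)
masks, scores, _ = sam_h.predict(embedding, foreground_points=((500, 375),))
```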

poetry.lock generated

@@ -651,21 +651,19 @@ tests = ["asttokens", "littleutils", "pytest", "rich"]
[[package]]
name = "filelock"
version = "3.12.3"
version = "3.12.4"
description = "A platform independent file lock."
optional = false
python-versions = ">=3.8"
files = [
{file = "filelock-3.12.3-py3-none-any.whl", hash = "sha256:f067e40ccc40f2b48395a80fcbd4728262fab54e232e090a4063ab804179efeb"},
{file = "filelock-3.12.3.tar.gz", hash = "sha256:0ecc1dd2ec4672a10c8550a8182f1bd0c0a5088470ecd5a125e45f49472fac3d"},
{file = "filelock-3.12.4-py3-none-any.whl", hash = "sha256:08c21d87ded6e2b9da6728c3dff51baf1dcecf973b768ef35bcbc3447edb9ad4"},
{file = "filelock-3.12.4.tar.gz", hash = "sha256:2e6f249f1f3654291606e046b09f1fd5eac39b360664c27f5aad072012f8bcbd"},
]
[package.dependencies]
typing-extensions = {version = ">=4.7.1", markers = "python_version < \"3.11\""}
[package.extras]
docs = ["furo (>=2023.7.26)", "sphinx (>=7.1.2)", "sphinx-autodoc-typehints (>=1.24)"]
testing = ["covdefaults (>=2.3)", "coverage (>=7.3)", "diff-cover (>=7.7)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)", "pytest-timeout (>=2.1)"]
typing = ["typing-extensions (>=4.7.1)"]
[[package]]
name = "frozenlist"
@@ -2507,6 +2505,25 @@ dev = ["click", "cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyl
doc = ["jupytext", "matplotlib (>2)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"]
test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
[[package]]
name = "segment_anything"
version = "1.0"
description = ""
optional = true
python-versions = "*"
files = []
develop = false
[package.extras]
all = ["matplotlib", "onnx", "onnxruntime", "opencv-python", "pycocotools"]
dev = ["black", "flake8", "isort", "mypy"]
[package.source]
type = "git"
url = "https://github.com/facebookresearch/segment-anything"
reference = "HEAD"
resolved_reference = "6fdee8f2727f4506cfbbe553e23b895e27956588"
[[package]]
name = "sentry-sdk"
version = "1.31.0"
@@ -3289,11 +3306,11 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link
testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"]
[extras]
conversion = ["diffusers", "transformers"]
test = ["diffusers", "invisible-watermark", "piq", "transformers"]
conversion = ["diffusers", "segment-anything", "transformers"]
test = ["diffusers", "invisible-watermark", "piq", "segment-anything", "transformers"]
training = ["bitsandbytes", "datasets", "loguru", "prodigyopt", "pydantic", "scipy", "tomli", "torchvision", "wandb"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.12"
content-hash = "b2646e04015b05eba32e6556b4aa37d3c5c31218c09fd0ea53d7540af558b7de"
content-hash = "6e35f01f2fc8611203da972939170d18371182443e4c6f45b4cfc40b4e785dff"


@@ -29,11 +29,12 @@ diffusers = {version = "^0.18.0", optional = true}
transformers = {version = "^4.27.4", optional = true}
piq = {version = "^0.7.1", optional = true}
invisible-watermark = {version = "^0.2.0", optional = true}
segment-anything = {git = "https://github.com/facebookresearch/segment-anything", optional = true}
[tool.poetry.extras]
training = ["datasets", "tomli", "wandb", "loguru", "bitsandbytes", "prodigyopt", "pydantic", "scipy", "torchvision"]
conversion = ["diffusers", "transformers"]
test = ["diffusers", "transformers", "piq", "invisible-watermark"]
conversion = ["diffusers", "transformers", "segment-anything"]
test = ["diffusers", "transformers", "piq", "invisible-watermark", "segment-anything"]
[tool.poetry.group.dev.dependencies]
black = "^23.1.0"


@@ -0,0 +1,236 @@
import argparse
import types
from typing import Any, Callable, cast
import torch
import torch.nn as nn
from torch import Tensor
import refiners.fluxion.layers as fl
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import manual_seed, save_to_safetensors
from refiners.foundationals.segment_anything.image_encoder import SAMViTH
from refiners.foundationals.segment_anything.prompt_encoder import PointEncoder, MaskEncoder
from segment_anything import build_sam_vit_h # type: ignore
from segment_anything.modeling.common import LayerNorm2d # type: ignore
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
class FacebookSAM(nn.Module):
image_encoder: nn.Module
prompt_encoder: nn.Module
mask_decoder: nn.Module
build_sam_vit_h = cast(Callable[[], FacebookSAM], build_sam_vit_h)
assert issubclass(LayerNorm2d, nn.Module)
custom_layers = {LayerNorm2d: fl.LayerNorm2d}
class Args(argparse.Namespace):
source_path: str
output_path: str
half: bool
verbose: bool
def convert_mask_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]:
state_dict: dict[str, Tensor] = {
"no_mask_embedding": nn.Parameter(data=prompt_encoder.no_mask_embed.weight.clone()), # type: ignore
}
refiners_mask_encoder = MaskEncoder()
# TODO: handle other weights
refiners_mask_encoder.load_state_dict(state_dict=state_dict, strict=False)
return state_dict
def convert_point_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]:
manual_seed(seed=0)
point_embeddings: list[Tensor] = [pe.weight for pe in prompt_encoder.point_embeddings] + [prompt_encoder.not_a_point_embed.weight] # type: ignore
pe = prompt_encoder.pe_layer.positional_encoding_gaussian_matrix # type: ignore
assert isinstance(pe, Tensor)
state_dict: dict[str, Tensor] = {
"Residual.Chain.PointTypeEmbedding.weight": nn.Parameter(data=torch.cat(tensors=point_embeddings, dim=0)),
"CoordinateEncoder.Linear.weight": nn.Parameter(data=pe.T.contiguous()),
}
refiners_prompt_encoder = PointEncoder()
refiners_prompt_encoder.load_state_dict(state_dict=state_dict)
return state_dict
def convert_vit(vit: nn.Module) -> dict[str, Tensor]:
manual_seed(seed=0)
refiners_sam_vit_h = SAMViTH()
converter = ModelConverter(
source_model=vit,
target_model=refiners_sam_vit_h,
custom_layer_mapping=custom_layers, # type: ignore
)
converter.skip_init_check = True
x = torch.randn(1, 3, 1024, 1024)
mapping = converter.map_state_dicts(source_args=(x,))
assert mapping
mapping["PositionalEncoder.Chain.Parameter.parameter"] = "pos_embed"
target_state_dict = refiners_sam_vit_h.state_dict()
del target_state_dict["PositionalEncoder.Chain.Parameter.parameter"]
source_state_dict = vit.state_dict()
pos_embed = source_state_dict["pos_embed"]
del source_state_dict["pos_embed"]
target_rel_keys = [
(
f"Transformer.TransformerLayer_{i}.Residual_1.Chain.FusedSelfAttention.RelativePositionAttention.horizontal_embedding",
f"Transformer.TransformerLayer_{i}.Residual_1.Chain.FusedSelfAttention.RelativePositionAttention.vertical_embedding",
)
for i in range(1, 33)
]
source_rel_keys = [(f"blocks.{i}.attn.rel_pos_w", f"blocks.{i}.attn.rel_pos_h") for i in range(32)]
rel_items: dict[str, Tensor] = {}
for (key_w, key_h), (target_key_w, target_key_h) in zip(source_rel_keys, target_rel_keys):
rel_items[target_key_w] = source_state_dict[key_w]
rel_items[target_key_h] = source_state_dict[key_h]
del source_state_dict[key_w]
del source_state_dict[key_h]
del target_state_dict[target_key_w]
del target_state_dict[target_key_h]
converted_source = converter._convert_state_dict( # pyright: ignore[reportPrivateUsage]
source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
)
converted_source["PositionalEncoder.Chain.Parameter.parameter"] = pos_embed # type: ignore
converted_source.update(rel_items)
refiners_sam_vit_h.load_state_dict(state_dict=converted_source)
assert converter.compare_models((x,), threshold=1e-3)
return converted_source
def convert_mask_decoder(mask_decoder: nn.Module) -> dict[str, Tensor]:
manual_seed(seed=0)
refiners_mask_decoder = MaskDecoder()
image_embedding = torch.randn(1, 256, 64, 64)
dense_positional_embedding = torch.randn(1, 256, 64, 64)
point_embedding = torch.randn(1, 3, 256)
mask_embedding = torch.randn(1, 256, 64, 64)
import refiners.fluxion.layers as fl
from segment_anything.modeling.common import LayerNorm2d # type: ignore
assert issubclass(LayerNorm2d, nn.Module)
custom_layers = {LayerNorm2d: fl.LayerNorm2d}
converter = ModelConverter(
source_model=mask_decoder,
target_model=refiners_mask_decoder,
custom_layer_mapping=custom_layers, # type: ignore
)
inputs = {
"image_embeddings": image_embedding,
"image_pe": dense_positional_embedding,
"sparse_prompt_embeddings": point_embedding,
"dense_prompt_embeddings": mask_embedding,
"multimask_output": True,
}
refiners_mask_decoder.set_image_embedding(image_embedding)
refiners_mask_decoder.set_point_embedding(point_embedding)
refiners_mask_decoder.set_mask_embedding(mask_embedding)
refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding)
mapping = converter.map_state_dicts(source_args=inputs, target_args={})
assert mapping is not None
mapping["IOUMaskEncoder"] = "iou_token"
state_dict = converter._convert_state_dict(source_state_dict=mask_decoder.state_dict(), target_state_dict=refiners_mask_decoder.state_dict(), state_dict_mapping=mapping) # type: ignore
state_dict["IOUMaskEncoder.weight"] = torch.cat(tensors=[mask_decoder.iou_token.weight, mask_decoder.mask_tokens.weight], dim=0) # type: ignore
refiners_mask_decoder.load_state_dict(state_dict=state_dict)
refiners_mask_decoder.set_image_embedding(image_embedding)
refiners_mask_decoder.set_point_embedding(point_embedding)
refiners_mask_decoder.set_mask_embedding(mask_embedding)
refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding)
    # Perform (1) upscaling then (2) mask prediction in this order (as in the official implementation) to make
    # `compare_models` happy (MaskPrediction's Matmul runs those in the reverse order by default)
matmul = refiners_mask_decoder.ensure_find(fl.Matmul)
def forward_swapped_order(self: Any, *args: Any) -> Any:
y = self[1](*args) # (1)
x = self[0](*args) # (2)
return torch.matmul(input=x, other=y)
matmul.forward = types.MethodType(forward_swapped_order, matmul)
assert converter.compare_models(source_args=inputs, target_args={}, threshold=1e-3)
return state_dict
def main() -> None:
    parser = argparse.ArgumentParser(
        description="Converts a Segment Anything (SAM ViT-H) checkpoint to refiners safetensors weights"
    )
parser.add_argument(
"--from",
type=str,
dest="source_path",
default="sam_vit_h_4b8939.pth",
# required=True,
help="Path to the Segment Anything model weights",
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
default="segment-anything-h.safetensors",
help="Output path for converted model (as safetensors).",
)
parser.add_argument("--half", action="store_true", default=False, help="Convert to half precision. Default: False")
parser.add_argument(
"--verbose",
action="store_true",
default=False,
help="Prints additional information during conversion. Default: False",
)
args = parser.parse_args(namespace=Args())
sam_h = build_sam_vit_h() # type: ignore
sam_h.load_state_dict(state_dict=torch.load(f=args.source_path)) # type: ignore
vit_state_dict = convert_vit(vit=sam_h.image_encoder)
mask_decoder_state_dict = convert_mask_decoder(mask_decoder=sam_h.mask_decoder)
point_encoder_state_dict = convert_point_encoder(prompt_encoder=sam_h.prompt_encoder)
mask_encoder_state_dict = convert_mask_encoder(prompt_encoder=sam_h.prompt_encoder)
output_state_dict = {
**{".".join(("image_encoder", key)): value for key, value in vit_state_dict.items()},
**{".".join(("mask_decoder", key)): value for key, value in mask_decoder_state_dict.items()},
**{".".join(("point_encoder", key)): value for key, value in point_encoder_state_dict.items()},
**{".".join(("mask_encoder", key)): value for key, value in mask_encoder_state_dict.items()},
}
if args.half:
output_state_dict = {key: value.half() for key, value in output_state_dict.items()}
save_to_safetensors(path=args.output_path, tensors=output_state_dict)
if __name__ == "__main__":
main()
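The output file namespaces every tensor under an `image_encoder.`, `mask_decoder.`, `point_encoder.` or `mask_encoder.` prefix, matching the sub-models of `SegmentAnything`. A quick sanity check of the converted file, sketched with the `safetensors` library (the path is the script's default output):

```python
from collections import Counter

from safetensors.torch import load_file

state_dict = load_file("segment-anything-h.safetensors")

# One bucket per sub-model: image_encoder, mask_decoder, point_encoder, mask_encoder.
print(Counter(key.split(".", 1)[0] for key in state_dict))
```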


@@ -0,0 +1,369 @@
from torch import device as Device, dtype as DType, Tensor
from refiners.fluxion.context import Contexts
import refiners.fluxion.layers as fl
from refiners.fluxion.utils import pad
from torch import nn
import torch
class PatchEncoder(fl.Chain):
def __init__(
self,
in_channels: int,
out_channels: int,
patch_size: int = 16,
use_bias: bool = True,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.in_channels = in_channels
self.out_channels = out_channels
self.patch_size = patch_size
self.use_bias = use_bias
super().__init__(
fl.Conv2d(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=(self.patch_size, self.patch_size),
stride=(self.patch_size, self.patch_size),
use_bias=self.use_bias,
device=device,
dtype=dtype,
),
fl.Permute(0, 2, 3, 1),
)
class PositionalEncoder(fl.Residual):
def __init__(
self,
embedding_dim: int,
image_embedding_size: tuple[int, int],
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.image_embedding_size = image_embedding_size
super().__init__(
fl.Parameter(
1,
image_embedding_size[0],
image_embedding_size[1],
embedding_dim,
device=device,
dtype=dtype,
),
)
class RelativePositionAttention(fl.WeightedModule):
def __init__(
self,
embedding_dim: int,
num_heads: int,
spatial_size: tuple[int, int],
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.head_dim = embedding_dim // num_heads
self.spatial_size = spatial_size
self.horizontal_embedding = nn.Parameter(
data=torch.zeros(2 * spatial_size[0] - 1, self.head_dim, device=device, dtype=dtype)
)
self.vertical_embedding = nn.Parameter(
data=torch.zeros(2 * spatial_size[1] - 1, self.head_dim, device=device, dtype=dtype)
)
@property
def device(self) -> Device:
return self.horizontal_embedding.device
@property
def dtype(self) -> DType:
return self.horizontal_embedding.dtype
def forward(self, x: Tensor) -> Tensor:
batch, height, width, _ = x.shape
x = (
x.reshape(batch, width * height, 3, self.num_heads, -1)
.permute(2, 0, 3, 1, 4)
.reshape(3, batch * self.num_heads, width * height, -1)
)
query, key, value = x.unbind(dim=0)
horizontal_relative_embedding, vertical_relative_embedding = self.compute_relative_embedding(x=query)
attention = (query * self.head_dim**-0.5) @ key.transpose(dim0=-2, dim1=-1)
# Order of operations is important here
attention = (
(attention.reshape(-1, height, width, height, width) + vertical_relative_embedding)
+ horizontal_relative_embedding
).reshape(attention.shape)
attention = attention.softmax(dim=-1)
attention = attention @ value
attention = (
attention.reshape(batch, self.num_heads, height, width, -1)
.permute(0, 2, 3, 1, 4)
.reshape(batch, height, width, -1)
)
return attention
def compute_relative_coords(self, size: int) -> Tensor:
x, y = torch.meshgrid(torch.arange(end=size), torch.arange(end=size), indexing="ij")
return x - y + size - 1
def compute_relative_embedding(self, x: Tensor) -> tuple[Tensor, Tensor]:
width, height = self.spatial_size
horizontal_coords = self.compute_relative_coords(size=width)
vertical_coords = self.compute_relative_coords(size=height)
horizontal_positional_embedding = self.horizontal_embedding[horizontal_coords]
vertical_positional_embedding = self.vertical_embedding[vertical_coords]
x = x.reshape(x.shape[0], width, height, -1)
horizontal_relative_embedding = torch.einsum("bhwc,wkc->bhwk", x, horizontal_positional_embedding).unsqueeze(
dim=-2
)
vertical_relative_embedding = torch.einsum("bhwc,hkc->bhwk", x, vertical_positional_embedding).unsqueeze(dim=-1)
return horizontal_relative_embedding, vertical_relative_embedding
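Side note on the decomposed relative position scheme: `compute_relative_coords` builds an index matrix whose entry (i, j) equals i - j + size - 1, so indexing an embedding table of length 2 * size - 1 returns one learned embedding per horizontal (or vertical) offset. A tiny sketch for size=3:

```python
import torch

size = 3
x, y = torch.meshgrid(torch.arange(end=size), torch.arange(end=size), indexing="ij")
coords = x - y + size - 1
print(coords)
# tensor([[2, 1, 0],
#         [3, 2, 1],
#         [4, 3, 2]])
# Row i, column j holds the shifted offset i - j + 2, so a table of 2 * size - 1 = 5
# embeddings covers every possible relative position along one axis.
```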
class FusedSelfAttention(fl.Chain):
def __init__(
self,
embedding_dim: int = 768,
spatial_size: tuple[int, int] = (64, 64),
num_heads: int = 1,
use_bias: bool = True,
is_causal: bool = False,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
assert (
embedding_dim % num_heads == 0
), f"Embedding dim (embedding_dim={embedding_dim}) must be divisible by num heads (num_heads={num_heads})"
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.use_bias = use_bias
self.is_causal = is_causal
super().__init__(
fl.Linear(
in_features=self.embedding_dim,
out_features=3 * self.embedding_dim,
bias=self.use_bias,
device=device,
dtype=dtype,
),
RelativePositionAttention(
embedding_dim=self.embedding_dim,
num_heads=self.num_heads,
spatial_size=spatial_size,
device=device,
dtype=dtype,
),
fl.Linear(
in_features=self.embedding_dim,
out_features=self.embedding_dim,
bias=True,
device=device,
dtype=dtype,
),
)
class FeedForward(fl.Chain):
def __init__(
self,
embedding_dim: int,
feedforward_dim: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.feedforward_dim = feedforward_dim
super().__init__(
fl.Linear(
in_features=self.embedding_dim,
out_features=self.feedforward_dim,
bias=True,
device=device,
dtype=dtype,
),
fl.GeLU(),
fl.Linear(
in_features=self.feedforward_dim,
out_features=self.embedding_dim,
bias=True,
device=device,
dtype=dtype,
),
)
class WindowPartition(fl.ContextModule):
def __init__(self) -> None:
super().__init__()
def forward(self, x: Tensor) -> Tensor:
batch, height, width, channels = x.shape
context = self.use_context(context_name="window_partition")
context.update({"original_height": height, "original_width": width})
window_size = context["window_size"]
padding_height = (window_size - height % window_size) % window_size
padding_width = (window_size - width % window_size) % window_size
if padding_height > 0 or padding_width > 0:
x = pad(x=x, pad=(0, 0, 0, padding_width, 0, padding_height))
padded_height, padded_width = height + padding_height, width + padding_width
context.update({"padded_height": padded_height, "padded_width": padded_width})
x = x.view(batch, padded_height // window_size, window_size, padded_width // window_size, window_size, channels)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, channels)
return windows
class WindowMerge(fl.ContextModule):
def __init__(self) -> None:
super().__init__()
def forward(self, x: Tensor) -> Tensor:
context = self.use_context(context_name="window_partition")
window_size = context["window_size"]
padded_height, padded_width = context["padded_height"], context["padded_width"]
original_height, original_width = context["original_height"], context["original_width"]
batch_size = x.shape[0] // (padded_height * padded_width // window_size // window_size)
x = x.view(batch_size, padded_height // window_size, padded_width // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, padded_height, padded_width, -1)
if padded_height > original_height or padded_width > original_width:
x = x[:, :original_height, :original_width, :].contiguous()
return x
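For reference, with SAM ViT-H's 64x64 patch grid and a window size of 14 (see SAMViT below), the partition pads the grid to 70x70 and yields 5x5 = 25 windows per image, which is why the FusedSelfAttention test in this commit feeds tensors of shape (25, 14, 14, 1280). A shape-only sketch under those hyperparameters:

```python
import torch

batch, height, width, channels = 1, 64, 64, 1280
window_size = 14

padding_height = (window_size - height % window_size) % window_size  # 6
padding_width = (window_size - width % window_size) % window_size  # 6
padded_height, padded_width = height + padding_height, width + padding_width  # 70, 70

x = torch.zeros(batch, padded_height, padded_width, channels)
windows = (
    x.view(batch, padded_height // window_size, window_size, padded_width // window_size, window_size, channels)
    .permute(0, 1, 3, 2, 4, 5)
    .contiguous()
    .view(-1, window_size, window_size, channels)
)
print(windows.shape)  # torch.Size([25, 14, 14, 1280])
```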
class TransformerLayer(fl.Chain):
def __init__(
self,
embedding_dim: int,
num_heads: int,
feedforward_dim: int,
image_embedding_size: tuple[int, int],
window_size: int | None = None,
layer_norm_eps: float = 1e-6,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.feedforward_dim = feedforward_dim
self.window_size = window_size
self.layer_norm_eps = layer_norm_eps
self.image_embedding_size = image_embedding_size
attention_spatial_size = (window_size, window_size) if window_size is not None else image_embedding_size
reshape_or_merge = (
WindowMerge()
if self.window_size is not None
else fl.Reshape(self.image_embedding_size[0], self.image_embedding_size[1], embedding_dim)
)
super().__init__(
fl.Residual(
fl.LayerNorm(normalized_shape=embedding_dim, eps=self.layer_norm_eps, device=device, dtype=dtype),
WindowPartition() if self.window_size is not None else fl.Identity(),
FusedSelfAttention(
embedding_dim=embedding_dim,
num_heads=num_heads,
spatial_size=attention_spatial_size,
device=device,
dtype=dtype,
),
reshape_or_merge,
),
fl.Residual(
fl.LayerNorm(normalized_shape=embedding_dim, eps=self.layer_norm_eps, device=device, dtype=dtype),
FeedForward(embedding_dim=embedding_dim, feedforward_dim=feedforward_dim, device=device, dtype=dtype),
),
)
def init_context(self) -> Contexts:
return {"window_partition": {"window_size": self.window_size}}
class Neck(fl.Chain):
def __init__(self, in_channels: int = 768, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.in_channels = in_channels
super().__init__(
fl.Permute(0, 3, 1, 2),
fl.Conv2d(
in_channels=self.in_channels,
out_channels=256,
kernel_size=1,
use_bias=False,
device=device,
dtype=dtype,
),
fl.LayerNorm2d(channels=256, device=device, dtype=dtype),
fl.Conv2d(
in_channels=256,
out_channels=256,
kernel_size=3,
padding=1,
use_bias=False,
device=device,
dtype=dtype,
),
fl.LayerNorm2d(channels=256, device=device, dtype=dtype),
)
class Transformer(fl.Chain):
pass
class SAMViT(fl.Chain):
def __init__(
self,
embedding_dim: int,
num_layers: int,
num_heads: int,
global_attention_indices: tuple[int, ...] | None = None,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.num_layers = num_layers
self.num_heads = num_heads
self.image_size = (1024, 1024)
self.patch_size = 16
self.window_size = 14
self.image_embedding_size = (self.image_size[0] // self.patch_size, self.image_size[1] // self.patch_size)
self.feed_forward_dim = 4 * self.embedding_dim
self.global_attention_indices = global_attention_indices or tuple()
super().__init__(
PatchEncoder(
in_channels=3, out_channels=embedding_dim, patch_size=self.patch_size, device=device, dtype=dtype
),
PositionalEncoder(
embedding_dim=embedding_dim, image_embedding_size=self.image_embedding_size, device=device, dtype=dtype
),
Transformer(
TransformerLayer(
embedding_dim=embedding_dim,
num_heads=num_heads,
feedforward_dim=self.feed_forward_dim,
window_size=self.window_size if i not in self.global_attention_indices else None,
image_embedding_size=self.image_embedding_size,
device=device,
dtype=dtype,
)
for i in range(num_layers)
),
Neck(in_channels=embedding_dim, device=device, dtype=dtype),
)
class SAMViTH(SAMViT):
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__(
embedding_dim=1280,
num_layers=32,
num_heads=16,
global_attention_indices=(7, 15, 23, 31),
device=device,
dtype=dtype,
)


@@ -0,0 +1,264 @@
import refiners.fluxion.layers as fl
from torch import device as Device, dtype as DType, Tensor, nn
import torch
from refiners.foundationals.segment_anything.transformer import (
SparseCrossDenseAttention,
    TwoWayTransformerLayer,
)
from refiners.fluxion.context import Contexts
class EmbeddingsAggregator(fl.ContextModule):
def __init__(self, num_output_mask: int = 3) -> None:
super().__init__()
self.num_mask_tokens = num_output_mask
def forward(self, iou_mask_tokens: Tensor) -> Tensor:
mask_decoder = self.ensure_parent
mask_decoder_context = mask_decoder.use_context(context_name="mask_decoder")
image_embedding = mask_decoder_context["image_embedding"]
point_embedding = mask_decoder_context["point_embedding"]
mask_embedding = mask_decoder_context["mask_embedding"]
dense_positional_embedding = mask_decoder_context["dense_positional_embedding"]
sparse_embedding = torch.cat(tensors=(iou_mask_tokens, point_embedding), dim=1)
dense_embedding = (image_embedding + mask_embedding).flatten(start_dim=2).transpose(1, 2)
if dense_positional_embedding.shape != dense_embedding.shape:
dense_positional_embedding = dense_positional_embedding.flatten(start_dim=2).transpose(1, 2)
mask_decoder_context.update(
{
"dense_embedding": dense_embedding,
"dense_positional_embedding": dense_positional_embedding,
"sparse_embedding": sparse_embedding,
}
)
mask_decoder.set_context(context="mask_decoder", value=mask_decoder_context)
return sparse_embedding
class Transformer(fl.Chain):
pass
class Hypernetworks(fl.Concatenate):
def __init__(
self,
embedding_dim: int = 256,
num_layers: int = 3,
num_mask_tokens: int = 3,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.num_layers = num_layers
self.num_mask_tokens = num_mask_tokens
super().__init__(
*[
fl.Chain(
fl.Slicing(dim=1, start=i + 1, length=1),
fl.MultiLinear(
input_dim=embedding_dim,
output_dim=embedding_dim // 8,
inner_dim=embedding_dim,
num_layers=num_layers,
device=device,
dtype=dtype,
),
)
for i in range(num_mask_tokens + 1)
],
dim=1,
)
class DenseEmbeddingUpscaling(fl.Chain):
def __init__(
self,
embedding_dim: int = 256,
dense_embedding_side_dim: int = 64,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.dense_embedding_side_dim = dense_embedding_side_dim
super().__init__(
fl.UseContext(context="mask_decoder", key="dense_embedding"),
fl.Transpose(dim0=1, dim1=2),
fl.Reshape(embedding_dim, dense_embedding_side_dim, dense_embedding_side_dim),
fl.ConvTranspose2d(
in_channels=embedding_dim,
out_channels=embedding_dim // 4,
kernel_size=2,
stride=2,
device=device,
dtype=dtype,
),
fl.LayerNorm2d(channels=embedding_dim // 4, device=device, dtype=dtype),
fl.GeLU(),
fl.ConvTranspose2d(
in_channels=embedding_dim // 4,
out_channels=embedding_dim // 8,
kernel_size=2,
stride=2,
device=device,
dtype=dtype,
),
fl.GeLU(),
fl.Flatten(start_dim=2),
)
class IOUMaskEncoder(fl.WeightedModule):
def __init__(
self,
embedding_dim: int = 256,
num_mask_tokens: int = 4,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.num_mask_tokens = num_mask_tokens
# aka prompt tokens + output token (for IoU scores prediction)
self.weight = nn.Parameter(data=torch.randn(num_mask_tokens + 1, embedding_dim, device=device, dtype=dtype))
def forward(self) -> Tensor:
return self.weight.unsqueeze(dim=0)
class MaskPrediction(fl.Chain):
def __init__(
self,
embedding_dim: int,
num_mask_tokens: int,
num_layers: int = 3,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.num_mask_tokens = num_mask_tokens
self.num_layers = num_layers
super().__init__(
fl.Matmul(
input=Hypernetworks(
embedding_dim=embedding_dim,
num_layers=num_layers,
num_mask_tokens=num_mask_tokens,
device=device,
dtype=dtype,
),
other=DenseEmbeddingUpscaling(embedding_dim=embedding_dim, device=device, dtype=dtype),
),
fl.Slicing(dim=1, start=1, length=num_mask_tokens),
fl.Reshape(num_mask_tokens, embedding_dim, embedding_dim),
)
class IOUPrediction(fl.Chain):
def __init__(
self,
embedding_dim: int,
num_layers: int,
num_mask_tokens: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.num_layers = num_layers
super().__init__(
fl.Slicing(dim=1, start=0, length=1),
fl.Squeeze(dim=0),
fl.MultiLinear(
input_dim=embedding_dim,
output_dim=num_mask_tokens + 1,
inner_dim=embedding_dim,
num_layers=num_layers,
device=device,
dtype=dtype,
),
fl.Slicing(dim=-1, start=1, length=num_mask_tokens),
)
class MaskDecoder(fl.Chain):
def __init__(
self,
embedding_dim: int = 256,
feed_forward_dim: int = 2048,
num_layers: int = 2,
num_output_mask: int = 3,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.num_mask_tokens = num_output_mask
self.feed_forward_dim = feed_forward_dim
self.num_layers = num_layers
super().__init__(
IOUMaskEncoder(
embedding_dim=embedding_dim, num_mask_tokens=num_output_mask + 1, device=device, dtype=dtype
),
EmbeddingsAggregator(num_output_mask=num_output_mask),
Transformer(
*(
                    TwoWayTransformerLayer(
embedding_dim=embedding_dim,
num_heads=8,
feed_forward_dim=feed_forward_dim,
use_residual_self_attention=i > 0,
device=device,
dtype=dtype,
)
for i in range(num_layers)
),
SparseCrossDenseAttention(embedding_dim=embedding_dim, device=device, dtype=dtype),
fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
),
fl.Parallel(
MaskPrediction(
embedding_dim=embedding_dim, num_mask_tokens=num_output_mask, device=device, dtype=dtype
),
IOUPrediction(
embedding_dim=embedding_dim,
num_layers=3,
num_mask_tokens=num_output_mask,
device=device,
dtype=dtype,
),
),
)
def init_context(self) -> Contexts:
return {
"mask_decoder": {
"image_embedding": None,
"point_embedding": None,
"mask_embedding": None,
"dense_positional_embedding": None,
}
}
def set_image_embedding(self, image_embedding: Tensor) -> None:
mask_decoder_context = self.use_context(context_name="mask_decoder")
mask_decoder_context["image_embedding"] = image_embedding
def set_point_embedding(self, point_embedding: Tensor) -> None:
mask_decoder_context = self.use_context(context_name="mask_decoder")
mask_decoder_context["point_embedding"] = point_embedding
def set_mask_embedding(self, mask_embedding: Tensor) -> None:
mask_decoder_context = self.use_context(context_name="mask_decoder")
mask_decoder_context["mask_embedding"] = mask_embedding
def set_dense_positional_embedding(self, dense_positional_embedding: Tensor) -> None:
mask_decoder_context = self.use_context(context_name="mask_decoder")
mask_decoder_context["dense_positional_embedding"] = dense_positional_embedding
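Unlike the official decoder, this one takes no positional inputs: the embeddings travel through the `mask_decoder` context filled by the four helpers above, and calling the chain yields `(low_res_masks, iou_predictions)`, exactly as `SegmentAnything.predict` does below. A minimal sketch with random embeddings, reusing the shapes from the conversion script and tests:

```python
import torch

from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder

mask_decoder = MaskDecoder()

# Shapes used by the conversion script and tests: a 64x64 image embedding and 3 sparse point embeddings.
mask_decoder.set_image_embedding(torch.randn(1, 256, 64, 64))
mask_decoder.set_dense_positional_embedding(torch.randn(1, 256, 64, 64))
mask_decoder.set_point_embedding(torch.randn(1, 3, 256))
mask_decoder.set_mask_embedding(torch.randn(1, 256, 64, 64))

low_res_masks, iou_predictions = mask_decoder()
print(low_res_masks.shape, iou_predictions.shape)  # three low-resolution masks and their predicted IoU scores
```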


@@ -0,0 +1,166 @@
from dataclasses import dataclass
from typing import Sequence
from PIL import Image
from torch import device as Device, dtype as DType, Tensor
import numpy as np
import torch
import refiners.fluxion.layers as fl
from refiners.fluxion.utils import image_to_tensor, normalize, pad, interpolate
from refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
@dataclass
class ImageEmbedding:
features: Tensor
original_image_size: tuple[int, int] # (height, width)
class SegmentAnything(fl.Module):
mask_threshold: float = 0.0
def __init__(
self,
image_encoder: SAMViT,
point_encoder: PointEncoder,
mask_encoder: MaskEncoder,
mask_decoder: MaskDecoder,
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
super().__init__()
self.device: Device = device if isinstance(device, Device) else Device(device=device)
self.dtype = dtype
self.image_encoder = image_encoder.to(device=self.device, dtype=self.dtype)
self.point_encoder = point_encoder.to(device=self.device, dtype=self.dtype)
self.mask_encoder = mask_encoder.to(device=self.device, dtype=self.dtype)
self.mask_decoder = mask_decoder.to(device=self.device, dtype=self.dtype)
def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
original_size = (image.height, image.width)
target_size = self.compute_target_size(original_size)
return ImageEmbedding(
features=self.image_encoder(self.preprocess_image(image=image, target_size=target_size)),
original_image_size=original_size,
)
def predict(
self,
input: Image.Image | ImageEmbedding,
foreground_points: Sequence[tuple[float, float]] | None = None,
background_points: Sequence[tuple[float, float]] | None = None,
box_points: Sequence[Sequence[tuple[float, float]]] | None = None,
masks: Sequence[Image.Image] | None = None,
binarize: bool = True,
) -> tuple[Tensor, Tensor, Tensor]:
if isinstance(input, ImageEmbedding):
original_size = input.original_image_size
target_size = self.compute_target_size(original_size)
image_embedding = input.features
else:
original_size = (input.height, input.width)
target_size = self.compute_target_size(original_size)
image_embedding = self.image_encoder(self.preprocess_image(image=input, target_size=target_size))
coordinates, type_mask = self.point_encoder.points_to_tensor(
foreground_points=foreground_points,
background_points=background_points,
box_points=box_points,
)
self.point_encoder.set_type_mask(type_mask=type_mask)
if masks is not None:
mask_tensor = torch.stack(
tensors=[image_to_tensor(image=mask, device=self.device, dtype=self.dtype) for mask in masks]
)
mask_embedding = self.mask_encoder(mask_tensor)
else:
mask_embedding = self.mask_encoder.get_no_mask_dense_embedding(
image_embedding_size=self.image_encoder.image_embedding_size
)
point_embedding = self.point_encoder(
self.normalize(coordinates, target_size=target_size, original_size=original_size)
)
dense_positional_embedding = self.point_encoder.get_dense_positional_embedding(
image_embedding_size=self.image_encoder.image_embedding_size
)
self.mask_decoder.set_image_embedding(image_embedding=image_embedding)
self.mask_decoder.set_mask_embedding(mask_embedding=mask_embedding)
self.mask_decoder.set_point_embedding(point_embedding=point_embedding)
self.mask_decoder.set_dense_positional_embedding(dense_positional_embedding=dense_positional_embedding)
low_res_masks, iou_predictions = self.mask_decoder()
high_res_masks = self.postprocess_masks(
masks=low_res_masks, target_size=target_size, original_size=original_size
)
if binarize:
high_res_masks = high_res_masks > self.mask_threshold
return high_res_masks, iou_predictions, low_res_masks
@property
def image_size(self) -> int:
w, h = self.image_encoder.image_size
assert w == h
return w
def compute_target_size(self, size: tuple[int, int]) -> tuple[int, int]:
oldh, oldw = size
scale = self.image_size * 1.0 / max(oldh, oldw)
newh, neww = oldh * scale, oldw * scale
neww = int(neww + 0.5)
newh = int(newh + 0.5)
return (newh, neww)
def preprocess_image(self, image: Image.Image, target_size: tuple[int, int]) -> Tensor:
h, w = target_size
padh = self.image_size - h
padw = self.image_size - w
image_tensor = torch.tensor(
np.array(image.resize((w, h), resample=Image.Resampling.BILINEAR)).astype(np.float32).transpose(2, 0, 1),
device=self.device,
dtype=self.dtype,
).unsqueeze(0)
return pad(
normalize(image_tensor, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]), (0, padw, 0, padh)
)
def normalize(self, coordinates: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
coordinates[:, :, 0] = ((coordinates[:, :, 0] * (target_size[1] / original_size[1])) + 0.5) / self.image_size
coordinates[:, :, 1] = ((coordinates[:, :, 1] * (target_size[0] / original_size[0])) + 0.5) / self.image_size
return coordinates
def postprocess_masks(self, masks: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
masks = interpolate(masks, factor=torch.Size((self.image_size, self.image_size)), mode="bilinear")
masks = masks[..., : target_size[0], : target_size[1]] # remove padding added at `preprocess_image` time
masks = interpolate(masks, factor=torch.Size(original_size), mode="bilinear")
return masks
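Preprocessing follows the official SAM recipe: the longest side is resized to 1024 pixels, the result is padded bottom/right into a 1024x1024 square, and `postprocess_masks` undoes both steps. For a hypothetical 1200x1800 (height x width) image, the arithmetic works out as follows:

```python
image_size = 1024
original_size = (1200, 1800)  # hypothetical (height, width)

scale = image_size / max(original_size)
target_size = (int(original_size[0] * scale + 0.5), int(original_size[1] * scale + 0.5))
print(target_size)  # (683, 1024)

# preprocess_image then pads the resized image to a square model input:
pad_height = image_size - target_size[0]  # 341 rows added at the bottom
pad_width = image_size - target_size[1]  # 0
```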
class SegmentAnythingH(SegmentAnything):
def __init__(
self,
image_encoder: SAMViTH | None = None,
point_encoder: PointEncoder | None = None,
mask_encoder: MaskEncoder | None = None,
mask_decoder: MaskDecoder | None = None,
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
image_encoder = image_encoder or SAMViTH()
point_encoder = point_encoder or PointEncoder()
mask_encoder = mask_encoder or MaskEncoder()
mask_decoder = mask_decoder or MaskDecoder()
super().__init__(
image_encoder=image_encoder,
point_encoder=point_encoder,
mask_encoder=mask_encoder,
mask_decoder=mask_decoder,
device=device,
dtype=dtype,
)


@@ -0,0 +1,190 @@
from enum import Enum, auto
from collections.abc import Sequence
from torch import device as Device, dtype as DType, Tensor, nn
import torch
from jaxtyping import Float, Int
import refiners.fluxion.layers as fl
from refiners.fluxion.context import Contexts
class CoordinateEncoder(fl.Chain):
def __init__(
self,
num_positional_features: int = 64,
scale: float = 1,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.num_positional_features = num_positional_features
self.scale = scale
super().__init__(
fl.Multiply(scale=2, bias=-1),
fl.Linear(in_features=2, out_features=num_positional_features, bias=False, device=device, dtype=dtype),
fl.Multiply(scale=2 * torch.pi * self.scale),
fl.Concatenate(fl.Sin(), fl.Cos(), dim=-1),
)
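CoordinateEncoder reproduces the random Fourier feature positional encoding of the original prompt encoder: coordinates normalized to [0, 1] are mapped to [-1, 1], projected through a fixed Gaussian matrix (the Linear above, whose weight is converted from positional_encoding_gaussian_matrix), scaled by 2π, and expanded into sine and cosine features. A functional sketch of the same computation, assuming an input of shape (num_points, 2) already normalized to [0, 1]:

```python
import torch

def encode_coordinates(coords: torch.Tensor, gaussian_matrix: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # coords: (..., 2) in [0, 1]; gaussian_matrix: (2, num_positional_features)
    projected = (2 * coords - 1) @ gaussian_matrix * (2 * torch.pi * scale)
    return torch.cat([projected.sin(), projected.cos()], dim=-1)

coords = torch.rand(5, 2)
gaussian_matrix = torch.randn(2, 64)
print(encode_coordinates(coords, gaussian_matrix).shape)  # torch.Size([5, 128])
```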
class PointType(Enum):
BACKGROUND = auto()
FOREGROUND = auto()
BOX_TOP_LEFT = auto()
BOX_BOTTOM_RIGHT = auto()
NOT_A_POINT = auto()
class PointTypeEmbedding(fl.WeightedModule, fl.ContextModule):
def __init__(self, embedding_dim: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.weight = nn.Parameter(data=torch.randn(len(PointType), self.embedding_dim, device=device, dtype=dtype))
def forward(self, type_mask: Int[Tensor, "1 num_points"]) -> Float[Tensor, "1 num_points embedding_dim"]:
assert isinstance(type_mask, Tensor), "type_mask must be a Tensor."
embeddings = torch.zeros(*type_mask.shape, self.embedding_dim).to(device=type_mask.device)
for type_id in PointType:
mask = type_mask == type_id.value
embeddings[mask] = self.weight[type_id.value - 1]
return embeddings
class PointEncoder(fl.Chain):
def __init__(
self, embedding_dim: int = 256, scale: float = 1, device: Device | str | None = None, dtype: DType | None = None
) -> None:
assert embedding_dim % 2 == 0, "embedding_dim must be divisible by 2."
self.embedding_dim = embedding_dim
self.scale = scale
super().__init__(
CoordinateEncoder(num_positional_features=embedding_dim // 2, scale=scale, device=device, dtype=dtype),
fl.Lambda(func=self.pad),
fl.Residual(
fl.UseContext(context="point_encoder", key="type_mask"),
PointTypeEmbedding(embedding_dim=embedding_dim, device=device, dtype=dtype),
),
)
def pad(self, x: Tensor) -> Tensor:
type_mask: Tensor = self.use_context("point_encoder")["type_mask"]
if torch.any((type_mask == PointType.BOX_TOP_LEFT.value) | (type_mask == PointType.BOX_BOTTOM_RIGHT.value)):
# Some boxes have been passed: no need to pad in this case
return x
type_mask = torch.cat(
[type_mask, torch.full((type_mask.shape[0], 1), PointType.NOT_A_POINT.value, device=type_mask.device)],
dim=1,
)
self.set_context(context="point_encoder", value={"type_mask": type_mask})
return torch.cat([x, torch.zeros((x.shape[0], 1, x.shape[-1]), device=x.device)], dim=1)
def init_context(self) -> Contexts:
return {
"point_encoder": {
"type_mask": None,
}
}
def set_type_mask(self, type_mask: Int[Tensor, "1 num_points"]) -> None:
self.set_context(context="point_encoder", value={"type_mask": type_mask})
def get_dense_positional_embedding(
self, image_embedding_size: tuple[int, int]
) -> Float[Tensor, "num_positional_features height width"]:
coordinate_encoder = self.ensure_find(layer_type=CoordinateEncoder)
height, width = image_embedding_size
grid = torch.ones((height, width), device=self.device, dtype=torch.float32)
y_embedding = grid.cumsum(dim=0) - 0.5
x_embedding = grid.cumsum(dim=1) - 0.5
y_embedding = y_embedding / height
x_embedding = x_embedding / width
positional_embedding = (
coordinate_encoder(torch.stack(tensors=[x_embedding, y_embedding], dim=-1))
.permute(2, 0, 1)
.unsqueeze(dim=0)
)
return positional_embedding
def points_to_tensor(
self,
foreground_points: Sequence[tuple[float, float]] | None = None,
background_points: Sequence[tuple[float, float]] | None = None,
not_a_points: Sequence[tuple[float, float]] | None = None,
box_points: Sequence[Sequence[tuple[float, float]]] | None = None,
) -> tuple[Float[Tensor, "1 num_points 2"], Int[Tensor, "1 num_points"]]:
foreground_points = foreground_points or []
background_points = background_points or []
not_a_points = not_a_points or []
box_points = box_points or []
top_left_points = [box[0] for box in box_points]
bottom_right_points = [box[1] for box in box_points]
coordinates: list[Tensor] = []
type_ids: list[Tensor] = []
# Must be in sync with PointType enum
for type_id, coords_seq in zip(
PointType, [background_points, foreground_points, top_left_points, bottom_right_points, not_a_points]
):
if len(coords_seq) > 0:
coords_tensor = torch.tensor(data=list(coords_seq), dtype=torch.float, device=self.device)
coordinates.append(coords_tensor)
point_ids = torch.tensor(data=[type_id.value] * len(coords_seq), dtype=torch.int, device=self.device)
type_ids.append(point_ids)
all_coordinates = torch.cat(tensors=coordinates, dim=0).unsqueeze(dim=0)
type_mask = torch.cat(tensors=type_ids, dim=0).unsqueeze(dim=0)
return all_coordinates, type_mask
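For clarity, points_to_tensor concatenates the prompt groups in the fixed PointType order (background, foreground, box corners, not-a-point) and returns one coordinate row plus one type id per point; when no box is given, the `pad` step above appends a NOT_A_POINT entry, mirroring the official prompt encoder. A small sketch of the expected output, assuming a PointEncoder instantiated on CPU:

```python
import torch

from refiners.foundationals.segment_anything.prompt_encoder import PointEncoder

point_encoder = PointEncoder()

coordinates, type_mask = point_encoder.points_to_tensor(
    foreground_points=[(500.0, 375.0)],
    background_points=[(1125.0, 625.0)],
)
print(coordinates)  # [[[1125., 625.], [500., 375.]]] -- background first, then foreground
print(type_mask)  # [[1, 2]] -- BACKGROUND=1, FOREGROUND=2
```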
class MaskEncoder(fl.Chain):
def __init__(
self,
embedding_dim: int = 256,
intermediate_channels: int = 16,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.intermediate_channels = intermediate_channels
super().__init__(
fl.Conv2d(
in_channels=1,
out_channels=self.intermediate_channels // 4,
kernel_size=2,
stride=2,
device=device,
dtype=dtype,
),
fl.LayerNorm2d(channels=self.intermediate_channels // 4, device=device, dtype=dtype),
fl.GeLU(),
fl.Conv2d(
in_channels=self.intermediate_channels // 4,
out_channels=self.intermediate_channels,
kernel_size=2,
stride=2,
device=device,
dtype=dtype,
),
fl.LayerNorm2d(channels=self.intermediate_channels, device=device, dtype=dtype),
fl.GeLU(),
fl.Conv2d(
in_channels=self.intermediate_channels,
out_channels=self.embedding_dim,
kernel_size=1,
device=device,
dtype=dtype,
),
)
self.register_parameter(
"no_mask_embedding", nn.Parameter(torch.randn(1, embedding_dim, device=device, dtype=dtype))
)
def get_no_mask_dense_embedding(
self, image_embedding_size: tuple[int, int], batch_size: int = 1
) -> Float[Tensor, "batch embedding_dim image_embedding_height image_embedding_width"]:
return self.no_mask_embedding.reshape(1, -1, 1, 1).expand(
batch_size, -1, image_embedding_size[0], image_embedding_size[1]
)


@@ -0,0 +1,157 @@
from torch import dtype as DType, device as Device
import refiners.fluxion.layers as fl
class CrossAttention(fl.Attention):
def __init__(
self,
embedding_dim: int,
cross_embedding_dim: int | None = None,
num_heads: int = 1,
inner_dim: int | None = None,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
embedding_dim=embedding_dim,
key_embedding_dim=cross_embedding_dim,
num_heads=num_heads,
inner_dim=inner_dim,
is_optimized=False,
device=device,
dtype=dtype,
)
self.cross_embedding_dim = cross_embedding_dim or embedding_dim
self.insert(index=0, module=fl.Parallel(fl.GetArg(index=0), fl.GetArg(index=1), fl.GetArg(index=1)))
class FeedForward(fl.Residual):
def __init__(
self, embedding_dim: int, feed_forward_dim: int, device: Device | str | None = None, dtype: DType | None = None
) -> None:
self.embedding_dim = embedding_dim
self.feed_forward_dim = feed_forward_dim
super().__init__(
fl.Linear(in_features=embedding_dim, out_features=feed_forward_dim, device=device, dtype=dtype),
fl.ReLU(),
fl.Linear(in_features=feed_forward_dim, out_features=embedding_dim, device=device, dtype=dtype),
)
class SparseSelfAttention(fl.Residual):
def __init__(
self,
embedding_dim: int,
inner_dim: int | None = None,
num_heads: int = 1,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
add_sparse_embedding = fl.Residual(fl.UseContext(context="mask_decoder", key="sparse_embedding"))
super().__init__(
fl.Parallel(add_sparse_embedding, add_sparse_embedding, fl.Identity()),
fl.Attention(
embedding_dim=embedding_dim,
inner_dim=inner_dim,
num_heads=num_heads,
is_optimized=False,
device=device,
dtype=dtype,
),
)
class SparseCrossDenseAttention(fl.Residual):
def __init__(
self, embedding_dim: int, num_heads: int = 8, device: Device | str | None = None, dtype: DType | None = None
) -> None:
self.embedding_dim = embedding_dim
self.num_heads = num_heads
super().__init__(
fl.Parallel(
fl.Residual(
fl.UseContext(context="mask_decoder", key="sparse_embedding"),
),
fl.Sum(
fl.UseContext(context="mask_decoder", key="dense_embedding"),
fl.UseContext(context="mask_decoder", key="dense_positional_embedding"),
),
fl.UseContext(context="mask_decoder", key="dense_embedding"),
),
fl.Attention(
embedding_dim=embedding_dim,
inner_dim=embedding_dim // 2,
num_heads=num_heads,
is_optimized=False,
device=device,
dtype=dtype,
),
)
class DenseCrossSparseAttention(fl.Chain):
def __init__(
self, embedding_dim: int, num_heads: int = 8, device: Device | str | None = None, dtype: DType | None = None
) -> None:
super().__init__(
fl.Parallel(
fl.Sum(
fl.UseContext(context="mask_decoder", key="dense_embedding"),
fl.UseContext(context="mask_decoder", key="dense_positional_embedding"),
),
fl.Residual(
fl.UseContext(context="mask_decoder", key="sparse_embedding"),
),
fl.Identity(),
),
fl.Attention(
embedding_dim=embedding_dim,
inner_dim=embedding_dim // 2,
num_heads=num_heads,
is_optimized=False,
device=device,
dtype=dtype,
),
)
class TwoWayTransformerLayer(fl.Chain):
def __init__(
self,
embedding_dim: int,
num_heads: int = 8,
feed_forward_dim: int = 2048,
use_residual_self_attention: bool = True,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.feed_forward_dim = feed_forward_dim
self_attention = (
SparseSelfAttention(embedding_dim=embedding_dim, num_heads=num_heads, device=device, dtype=dtype)
if use_residual_self_attention
else fl.SelfAttention(
embedding_dim=embedding_dim, num_heads=num_heads, is_optimized=False, device=device, dtype=dtype
)
)
super().__init__(
self_attention,
fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
SparseCrossDenseAttention(embedding_dim=embedding_dim, num_heads=num_heads, device=device, dtype=dtype),
fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
FeedForward(embedding_dim=embedding_dim, feed_forward_dim=feed_forward_dim, device=device, dtype=dtype),
fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
fl.Passthrough(
fl.Sum(
fl.UseContext(context="mask_decoder", key="dense_embedding"),
DenseCrossSparseAttention(
embedding_dim=embedding_dim, num_heads=num_heads, device=device, dtype=dtype
),
),
fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
fl.SetContext(context="mask_decoder", key="dense_embedding"),
),
)


@@ -28,3 +28,8 @@ def test_e2e_path() -> Path:
@fixture(scope="session")
def test_textual_inversion_path() -> Path:
return PARENT_PATH / "foundationals" / "clip" / "test_concepts_ref"
@fixture(scope="session")
def test_sam_path() -> Path:
return PARENT_PATH / "foundationals" / "segment_anything"


@@ -0,0 +1,317 @@
from math import isclose
from pathlib import Path
from typing import cast
from warnings import warn
import pytest
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from torch import Tensor
from refiners.fluxion import manual_seed
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import image_to_tensor
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention
from refiners.foundationals.segment_anything.model import SegmentAnythingH
from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer
from tests.foundationals.segment_anything.utils import (
FacebookSAM,
FacebookSAMPredictor,
SAMPrompt,
intersection_over_union,
)
# See the official predictor_example.ipynb notebook (note: mask_input is not yet properly supported)
PROMPTS: list[SAMPrompt] = [
SAMPrompt(foreground_points=((500, 375),)),
SAMPrompt(background_points=((500, 375),)),
SAMPrompt(foreground_points=((500, 375), (1125, 625))),
SAMPrompt(foreground_points=((500, 375),), background_points=((1125, 625),)),
SAMPrompt(box_points=[[(425, 600), (700, 875)]]),
SAMPrompt(box_points=[[(425, 600), (700, 875)]], background_points=((575, 750),)),
]
@pytest.fixture(params=PROMPTS)
def prompt(request: pytest.FixtureRequest) -> SAMPrompt:
return request.param
@pytest.fixture
def one_prompt() -> SAMPrompt:
return PROMPTS[0]
@pytest.fixture(scope="module")
def facebook_sam_h_weights(test_weights_path: Path) -> Path:
sam_h_weights = test_weights_path / "sam_vit_h_4b8939.pth"
if not sam_h_weights.is_file():
warn(f"could not find weights at {sam_h_weights}, skipping")
pytest.skip(allow_module_level=True)
return sam_h_weights
@pytest.fixture(scope="module")
def sam_h_weights(test_weights_path: Path) -> Path:
sam_h_weights = test_weights_path / "segment-anything-h.safetensors"
if not sam_h_weights.is_file():
warn(f"could not find weights at {sam_h_weights}, skipping")
pytest.skip(allow_module_level=True)
return sam_h_weights
@pytest.fixture(scope="module")
def facebook_sam_h(facebook_sam_h_weights: Path, test_device: torch.device) -> FacebookSAM:
from segment_anything import build_sam_vit_h # type: ignore
sam_h = cast(FacebookSAM, build_sam_vit_h())
sam_h.load_state_dict(state_dict=torch.load(f=facebook_sam_h_weights)) # type: ignore
return sam_h.to(device=test_device)
@pytest.fixture(scope="module")
def facebook_sam_h_predictor(facebook_sam_h: FacebookSAM) -> FacebookSAMPredictor:
from segment_anything import SamPredictor # type: ignore
from segment_anything.modeling import Sam # type: ignore
predictor = SamPredictor(cast(Sam, facebook_sam_h))
return cast(FacebookSAMPredictor, predictor)
@pytest.fixture(scope="module")
def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
sam_h = SegmentAnythingH(device=test_device)
    # TODO: make strict=True when the MaskEncoder conversion is done
sam_h.load_from_safetensors(tensors_path=sam_h_weights, strict=False)
return sam_h
@pytest.fixture(scope="module")
def ref_path(test_sam_path: Path) -> Path:
return test_sam_path / "test_sam_ref"
@pytest.fixture
def truck(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "truck.jpg").convert("RGB")
@torch.no_grad()
def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None:
manual_seed(seed=0)
x = torch.randn(25, 14, 14, 1280, device=facebook_sam_h.device)
attention = cast(nn.Module, facebook_sam_h.image_encoder.blocks[0].attn) # type: ignore
refiners_attention = FusedSelfAttention(
embedding_dim=1280, num_heads=16, spatial_size=(14, 14), device=facebook_sam_h.device
)
refiners_attention.Linear_1.weight = attention.qkv.weight # type: ignore
refiners_attention.Linear_1.bias = attention.qkv.bias # type: ignore
refiners_attention.Linear_2.weight = attention.proj.weight # type: ignore
refiners_attention.Linear_2.bias = attention.proj.bias # type: ignore
refiners_attention.RelativePositionAttention.horizontal_embedding = attention.rel_pos_w
refiners_attention.RelativePositionAttention.vertical_embedding = attention.rel_pos_h
y_1 = attention(x)
assert y_1.shape == x.shape
y_2 = refiners_attention(x)
assert y_2.shape == x.shape
assert torch.equal(input=y_1, other=y_2)
@torch.no_grad()
def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, truck: Image.Image) -> None:
image_tensor = image_to_tensor(image=truck.resize(size=(1024, 1024)), device=facebook_sam_h.device)
y_1 = facebook_sam_h.image_encoder(image_tensor)
y_2 = sam_h.image_encoder(image_tensor)
assert torch.equal(input=y_1, other=y_2)
@torch.no_grad()
def test_prompt_encoder_dense_positional_embedding(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
facebook_prompt_encoder = facebook_sam_h.prompt_encoder
refiners_prompt_encoder = sam_h.point_encoder
facebook_dense_pe: Tensor = cast(Tensor, facebook_prompt_encoder.get_dense_pe()) # type: ignore
refiners_dense_pe = refiners_prompt_encoder.get_dense_positional_embedding(image_embedding_size=(64, 64))
assert torch.equal(input=refiners_dense_pe, other=facebook_dense_pe)
@torch.no_grad()
def test_prompt_encoder_no_mask_dense_embedding(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
facebook_prompt_encoder = facebook_sam_h.prompt_encoder
refiners_prompt_encoder = sam_h.mask_encoder
_, facebook_dense_pe = facebook_prompt_encoder(points=None, boxes=None, masks=None)
refiners_dense_pe = refiners_prompt_encoder.get_no_mask_dense_embedding(image_embedding_size=(64, 64))
assert torch.equal(input=refiners_dense_pe, other=facebook_dense_pe)
@torch.no_grad()
def test_point_encoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH, prompt: SAMPrompt) -> None:
facebook_prompt_encoder = facebook_sam_h.prompt_encoder
refiners_prompt_encoder = sam_h.point_encoder
facebook_sparse_pe, _ = facebook_prompt_encoder(
**prompt.facebook_prompt_encoder_kwargs(device=facebook_sam_h.device)
)
coordinates, type_mask = refiners_prompt_encoder.points_to_tensor(**prompt.__dict__)
# Shift to center of pixel + normalize in [0, 1] (see `_embed_points` in segment-anything official repo)
coordinates[:, :, 0] = (coordinates[:, :, 0] + 0.5) / 1024.0
coordinates[:, :, 1] = (coordinates[:, :, 1] + 0.5) / 1024.0
refiners_prompt_encoder.set_type_mask(type_mask=type_mask)
refiners_sparse_pe = refiners_prompt_encoder(coordinates)
assert torch.equal(input=refiners_sparse_pe, other=facebook_sparse_pe)
@torch.no_grad()
def test_two_way_transformer(facebook_sam_h: FacebookSAM) -> None:
dense_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device)
dense_positional_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device)
sparse_embedding = torch.randn(1, 3, 256, device=facebook_sam_h.device)
    refiners_layer = TwoWayTransformerLayer(
embedding_dim=256, feed_forward_dim=2048, num_heads=8, device=facebook_sam_h.device
)
facebook_layer = facebook_sam_h.mask_decoder.transformer.layers[1] # type: ignore
assert isinstance(facebook_layer, nn.Module)
refiners_layer.set_context(
context="mask_decoder",
value={
"dense_embedding": dense_embedding,
"dense_positional_embedding": dense_positional_embedding,
"sparse_embedding": sparse_embedding,
},
)
facebook_inputs = {
"queries": sparse_embedding,
"keys": dense_embedding,
"query_pe": sparse_embedding,
"key_pe": dense_positional_embedding,
}
converter = ModelConverter(
source_model=facebook_layer,
target_model=refiners_layer,
skip_output_check=True, # done below, manually
)
assert converter.run(source_args=facebook_inputs, target_args=(sparse_embedding,))
refiners_layer.set_context(
context="mask_decoder",
value={
"dense_embedding": dense_embedding,
"dense_positional_embedding": dense_positional_embedding,
"sparse_embedding": sparse_embedding,
},
)
y_1 = facebook_layer(**facebook_inputs)[0]
y_2 = refiners_layer(sparse_embedding)[0]
assert torch.equal(input=y_1, other=y_2)
@torch.no_grad()
def test_mask_decoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
manual_seed(seed=0)
facebook_mask_decoder = facebook_sam_h.mask_decoder
refiners_mask_decoder = sam_h.mask_decoder
image_embedding = torch.randn(1, 256, 64, 64, device=facebook_sam_h.device)
dense_positional_embedding = torch.randn(1, 256, 64, 64, device=facebook_sam_h.device)
point_embedding = torch.randn(1, 3, 256, device=facebook_sam_h.device)
mask_embedding = torch.randn(1, 256, 64, 64, device=facebook_sam_h.device)
import refiners.fluxion.layers as fl
from segment_anything.modeling.common import LayerNorm2d # type: ignore
assert issubclass(LayerNorm2d, nn.Module)
custom_layers = {LayerNorm2d: fl.LayerNorm2d}
converter = ModelConverter(
source_model=facebook_mask_decoder,
target_model=refiners_mask_decoder,
custom_layer_mapping=custom_layers, # type: ignore
)
inputs = {
"image_embeddings": image_embedding,
"image_pe": dense_positional_embedding,
"sparse_prompt_embeddings": point_embedding,
"dense_prompt_embeddings": mask_embedding,
"multimask_output": True,
}
refiners_mask_decoder.set_image_embedding(image_embedding)
refiners_mask_decoder.set_point_embedding(point_embedding)
refiners_mask_decoder.set_mask_embedding(mask_embedding)
refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding)
mapping = converter.map_state_dicts(source_args=inputs, target_args={})
assert mapping is not None
mapping["IOUMaskEncoder"] = "iou_token"
state_dict = converter._convert_state_dict(source_state_dict=facebook_mask_decoder.state_dict(), target_state_dict=refiners_mask_decoder.state_dict(), state_dict_mapping=mapping) # type: ignore
state_dict["IOUMaskEncoder.weight"] = torch.cat([facebook_mask_decoder.iou_token.weight, facebook_mask_decoder.mask_tokens.weight], dim=0) # type: ignore
refiners_mask_decoder.load_state_dict(state_dict=state_dict)
facebook_output = facebook_mask_decoder(**inputs)
refiners_mask_decoder.set_image_embedding(image_embedding)
refiners_mask_decoder.set_point_embedding(point_embedding)
refiners_mask_decoder.set_mask_embedding(mask_embedding)
refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding)
mask_prediction, iou_prediction = refiners_mask_decoder()
facebook_masks = facebook_output[0]
facebook_prediction = facebook_output[1]
assert torch.equal(input=mask_prediction, other=facebook_masks)
assert torch.equal(input=iou_prediction, other=facebook_prediction)
@torch.no_grad()
def test_predictor(
facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, prompt: SAMPrompt
) -> None:
predictor = facebook_sam_h_predictor
predictor.set_image(np.array(truck))
facebook_masks, facebook_scores, _ = predictor.predict(**prompt.facebook_predict_kwargs()) # type: ignore
assert len(facebook_masks) == 3
masks, scores, _ = sam_h.predict(truck, **prompt.__dict__)
masks = masks.squeeze(0)
scores = scores.squeeze(0)
assert len(masks) == 3
for i in range(3):
mask_prediction = masks[i].cpu()
facebook_mask = torch.as_tensor(facebook_masks[i])
assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05)
assert isclose(scores[i].item(), facebook_scores[i].item(), rel_tol=1e-05)
@torch.no_grad()
def test_predictor_image_embedding(sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt) -> None:
masks_ref, scores_ref, _ = sam_h.predict(truck, **one_prompt.__dict__)
image_embedding = sam_h.compute_image_embedding(truck)
masks, scores, _ = sam_h.predict(image_embedding, **one_prompt.__dict__)
assert torch.equal(masks, masks_ref)
assert torch.equal(scores_ref, scores)


@@ -0,0 +1,3 @@
# Note about this data
`truck.jpg` is one of the [images](https://github.com/facebookresearch/segment-anything/tree/main/notebooks/images) used in the official [segment-anything notebooks](https://github.com/facebookresearch/segment-anything/tree/main/notebooks).

Binary file not shown (new image, 265 KiB).


@@ -0,0 +1,107 @@
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, TypedDict
from jaxtyping import Bool
from torch import Tensor, nn
import numpy as np
import numpy.typing as npt
import torch
NDArrayUInt8 = npt.NDArray[np.uint8]
NDArray = npt.NDArray[Any]
class SAMInput(TypedDict):
image: Tensor
original_size: tuple[int, int]
point_coords: Tensor | None
point_labels: Tensor | None
boxes: Tensor | None
mask_inputs: Tensor | None
class SAMOutput(TypedDict):
masks: Tensor
iou_predictions: Tensor
low_res_logits: Tensor
class FacebookSAM(nn.Module):
image_encoder: nn.Module
prompt_encoder: nn.Module
mask_decoder: nn.Module
def __call__(self, batched_input: list[SAMInput], multimask_output: bool) -> list[SAMOutput]: ...
@property
def device(self) -> Any: ...
class FacebookSAMPredictor:
model: FacebookSAM
def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ...
def predict(
self,
point_coords: NDArray | None = None,
point_labels: NDArray | None = None,
box: NDArray | None = None,
mask_input: NDArray | None = None,
multimask_output: bool = True,
return_logits: bool = False,
) -> tuple[NDArray, NDArray, NDArray]: ...
@dataclass
class SAMPrompt:
foreground_points: Sequence[tuple[float, float]] | None = None
background_points: Sequence[tuple[float, float]] | None = None
box_points: Sequence[Sequence[tuple[float, float]]] | None = None
# TODO: support masks
# masks: Sequence[Image.Image] | None = None
def facebook_predict_kwargs(self) -> dict[str, NDArray]:
prompt: dict[str, NDArray] = {}
# Note: the order matters since `points_to_tensor` processes points that way (background -> foreground -> etc)
if self.background_points:
prompt["point_coords"] = np.array(self.background_points)
prompt["point_labels"] = np.array([0] * len(self.background_points))
if self.foreground_points:
coords = np.array(self.foreground_points)
prompt["point_coords"] = (
coords if "point_coords" not in prompt else np.concatenate((prompt["point_coords"], coords))
)
labels = np.array([1] * len(self.foreground_points))
prompt["point_labels"] = (
labels if "point_labels" not in prompt else np.concatenate((prompt["point_labels"], labels))
)
if self.box_points:
prompt["box"] = np.array([coord for batch in self.box_points for xy in batch for coord in xy]).reshape(
len(self.box_points), 4
)
return prompt
def facebook_prompt_encoder_kwargs(self, device: torch.device | None = None):
prompt = self.facebook_predict_kwargs()
coords: Tensor | None = None
labels: Tensor | None = None
boxes: Tensor | None = None
if "point_coords" in prompt:
coords = torch.as_tensor(prompt["point_coords"], dtype=torch.float, device=device).unsqueeze(0)
if "point_labels" in prompt:
labels = torch.as_tensor(prompt["point_labels"], dtype=torch.int, device=device).unsqueeze(0)
if "box" in prompt:
boxes = torch.as_tensor(prompt["box"], dtype=torch.float, device=device).unsqueeze(0)
points = (coords, labels) if coords is not None else None
# TODO: support masks
return {"points": points, "boxes": boxes, "masks": None}
def intersection_over_union(
input_mask: Bool[Tensor, "height width"], other_mask: Bool[Tensor, "height width"]
) -> float:
inter = (input_mask & other_mask).sum(dtype=torch.float32).item()
union = (input_mask | other_mask).sum(dtype=torch.float32).item()
return inter / union if union > 0 else 1.0