mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
add Segment Anything (SAM) to foundational models
Note: dense prompts (i.e. masks) support is still partial (see MaskEncoder) Co-authored-by: Cédric Deltheil <cedric@deltheil.me>
This commit is contained in:
parent
2faff9f57a
commit
282578ddc0
35
poetry.lock
generated
35
poetry.lock
generated
|
@ -651,21 +651,19 @@ tests = ["asttokens", "littleutils", "pytest", "rich"]
|
|||
|
||||
[[package]]
|
||||
name = "filelock"
|
||||
version = "3.12.3"
|
||||
version = "3.12.4"
|
||||
description = "A platform independent file lock."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "filelock-3.12.3-py3-none-any.whl", hash = "sha256:f067e40ccc40f2b48395a80fcbd4728262fab54e232e090a4063ab804179efeb"},
|
||||
{file = "filelock-3.12.3.tar.gz", hash = "sha256:0ecc1dd2ec4672a10c8550a8182f1bd0c0a5088470ecd5a125e45f49472fac3d"},
|
||||
{file = "filelock-3.12.4-py3-none-any.whl", hash = "sha256:08c21d87ded6e2b9da6728c3dff51baf1dcecf973b768ef35bcbc3447edb9ad4"},
|
||||
{file = "filelock-3.12.4.tar.gz", hash = "sha256:2e6f249f1f3654291606e046b09f1fd5eac39b360664c27f5aad072012f8bcbd"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
typing-extensions = {version = ">=4.7.1", markers = "python_version < \"3.11\""}
|
||||
|
||||
[package.extras]
|
||||
docs = ["furo (>=2023.7.26)", "sphinx (>=7.1.2)", "sphinx-autodoc-typehints (>=1.24)"]
|
||||
testing = ["covdefaults (>=2.3)", "coverage (>=7.3)", "diff-cover (>=7.7)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)", "pytest-timeout (>=2.1)"]
|
||||
typing = ["typing-extensions (>=4.7.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "frozenlist"
|
||||
|
@ -2507,6 +2505,25 @@ dev = ["click", "cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyl
|
|||
doc = ["jupytext", "matplotlib (>2)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"]
|
||||
test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
|
||||
|
||||
[[package]]
|
||||
name = "segment_anything"
|
||||
version = "1.0"
|
||||
description = ""
|
||||
optional = true
|
||||
python-versions = "*"
|
||||
files = []
|
||||
develop = false
|
||||
|
||||
[package.extras]
|
||||
all = ["matplotlib", "onnx", "onnxruntime", "opencv-python", "pycocotools"]
|
||||
dev = ["black", "flake8", "isort", "mypy"]
|
||||
|
||||
[package.source]
|
||||
type = "git"
|
||||
url = "https://github.com/facebookresearch/segment-anything"
|
||||
reference = "HEAD"
|
||||
resolved_reference = "6fdee8f2727f4506cfbbe553e23b895e27956588"
|
||||
|
||||
[[package]]
|
||||
name = "sentry-sdk"
|
||||
version = "1.31.0"
|
||||
|
@ -3289,11 +3306,11 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link
|
|||
testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"]
|
||||
|
||||
[extras]
|
||||
conversion = ["diffusers", "transformers"]
|
||||
test = ["diffusers", "invisible-watermark", "piq", "transformers"]
|
||||
conversion = ["diffusers", "segment-anything", "transformers"]
|
||||
test = ["diffusers", "invisible-watermark", "piq", "segment-anything", "transformers"]
|
||||
training = ["bitsandbytes", "datasets", "loguru", "prodigyopt", "pydantic", "scipy", "tomli", "torchvision", "wandb"]
|
||||
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10,<3.12"
|
||||
content-hash = "b2646e04015b05eba32e6556b4aa37d3c5c31218c09fd0ea53d7540af558b7de"
|
||||
content-hash = "6e35f01f2fc8611203da972939170d18371182443e4c6f45b4cfc40b4e785dff"
|
||||
|
|
|
@ -29,11 +29,12 @@ diffusers = {version = "^0.18.0", optional = true}
|
|||
transformers = {version = "^4.27.4", optional = true}
|
||||
piq = {version = "^0.7.1", optional = true}
|
||||
invisible-watermark = {version = "^0.2.0", optional = true}
|
||||
segment-anything = {git = "https://github.com/facebookresearch/segment-anything", optional = true}
|
||||
|
||||
[tool.poetry.extras]
|
||||
training = ["datasets", "tomli", "wandb", "loguru", "bitsandbytes", "prodigyopt", "pydantic", "scipy", "torchvision"]
|
||||
conversion = ["diffusers", "transformers"]
|
||||
test = ["diffusers", "transformers", "piq", "invisible-watermark"]
|
||||
conversion = ["diffusers", "transformers", "segment-anything"]
|
||||
test = ["diffusers", "transformers", "piq", "invisible-watermark", "segment-anything"]
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
black = "^23.1.0"
|
||||
|
|
236
scripts/conversion/convert_segment_anything.py
Normal file
236
scripts/conversion/convert_segment_anything.py
Normal file
|
@ -0,0 +1,236 @@
|
|||
import argparse
|
||||
import types
|
||||
from typing import Any, Callable, cast
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.model_converter import ModelConverter
|
||||
from refiners.fluxion.utils import manual_seed, save_to_safetensors
|
||||
from refiners.foundationals.segment_anything.image_encoder import SAMViTH
|
||||
from refiners.foundationals.segment_anything.prompt_encoder import PointEncoder, MaskEncoder
|
||||
|
||||
from segment_anything import build_sam_vit_h # type: ignore
|
||||
from segment_anything.modeling.common import LayerNorm2d # type: ignore
|
||||
|
||||
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
|
||||
|
||||
|
||||
class FacebookSAM(nn.Module):
|
||||
image_encoder: nn.Module
|
||||
prompt_encoder: nn.Module
|
||||
mask_decoder: nn.Module
|
||||
|
||||
|
||||
build_sam_vit_h = cast(Callable[[], FacebookSAM], build_sam_vit_h)
|
||||
|
||||
|
||||
assert issubclass(LayerNorm2d, nn.Module)
|
||||
custom_layers = {LayerNorm2d: fl.LayerNorm2d}
|
||||
|
||||
|
||||
class Args(argparse.Namespace):
|
||||
source_path: str
|
||||
output_path: str
|
||||
half: bool
|
||||
verbose: bool
|
||||
|
||||
|
||||
def convert_mask_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]:
|
||||
state_dict: dict[str, Tensor] = {
|
||||
"no_mask_embedding": nn.Parameter(data=prompt_encoder.no_mask_embed.weight.clone()), # type: ignore
|
||||
}
|
||||
|
||||
refiners_mask_encoder = MaskEncoder()
|
||||
# TODO: handle other weights
|
||||
refiners_mask_encoder.load_state_dict(state_dict=state_dict, strict=False)
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def convert_point_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]:
|
||||
manual_seed(seed=0)
|
||||
point_embeddings: list[Tensor] = [pe.weight for pe in prompt_encoder.point_embeddings] + [prompt_encoder.not_a_point_embed.weight] # type: ignore
|
||||
pe = prompt_encoder.pe_layer.positional_encoding_gaussian_matrix # type: ignore
|
||||
assert isinstance(pe, Tensor)
|
||||
state_dict: dict[str, Tensor] = {
|
||||
"Residual.Chain.PointTypeEmbedding.weight": nn.Parameter(data=torch.cat(tensors=point_embeddings, dim=0)),
|
||||
"CoordinateEncoder.Linear.weight": nn.Parameter(data=pe.T.contiguous()),
|
||||
}
|
||||
|
||||
refiners_prompt_encoder = PointEncoder()
|
||||
refiners_prompt_encoder.load_state_dict(state_dict=state_dict)
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def convert_vit(vit: nn.Module) -> dict[str, Tensor]:
|
||||
manual_seed(seed=0)
|
||||
refiners_sam_vit_h = SAMViTH()
|
||||
|
||||
converter = ModelConverter(
|
||||
source_model=vit,
|
||||
target_model=refiners_sam_vit_h,
|
||||
custom_layer_mapping=custom_layers, # type: ignore
|
||||
)
|
||||
converter.skip_init_check = True
|
||||
|
||||
x = torch.randn(1, 3, 1024, 1024)
|
||||
mapping = converter.map_state_dicts(source_args=(x,))
|
||||
assert mapping
|
||||
|
||||
mapping["PositionalEncoder.Chain.Parameter.parameter"] = "pos_embed"
|
||||
|
||||
target_state_dict = refiners_sam_vit_h.state_dict()
|
||||
del target_state_dict["PositionalEncoder.Chain.Parameter.parameter"]
|
||||
|
||||
source_state_dict = vit.state_dict()
|
||||
pos_embed = source_state_dict["pos_embed"]
|
||||
del source_state_dict["pos_embed"]
|
||||
|
||||
target_rel_keys = [
|
||||
(
|
||||
f"Transformer.TransformerLayer_{i}.Residual_1.Chain.FusedSelfAttention.RelativePositionAttention.horizontal_embedding",
|
||||
f"Transformer.TransformerLayer_{i}.Residual_1.Chain.FusedSelfAttention.RelativePositionAttention.vertical_embedding",
|
||||
)
|
||||
for i in range(1, 33)
|
||||
]
|
||||
source_rel_keys = [(f"blocks.{i}.attn.rel_pos_w", f"blocks.{i}.attn.rel_pos_h") for i in range(32)]
|
||||
|
||||
rel_items: dict[str, Tensor] = {}
|
||||
|
||||
for (key_w, key_h), (target_key_w, target_key_h) in zip(source_rel_keys, target_rel_keys):
|
||||
rel_items[target_key_w] = source_state_dict[key_w]
|
||||
rel_items[target_key_h] = source_state_dict[key_h]
|
||||
del source_state_dict[key_w]
|
||||
del source_state_dict[key_h]
|
||||
del target_state_dict[target_key_w]
|
||||
del target_state_dict[target_key_h]
|
||||
|
||||
converted_source = converter._convert_state_dict( # pyright: ignore[reportPrivateUsage]
|
||||
source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
|
||||
)
|
||||
|
||||
converted_source["PositionalEncoder.Chain.Parameter.parameter"] = pos_embed # type: ignore
|
||||
converted_source.update(rel_items)
|
||||
|
||||
refiners_sam_vit_h.load_state_dict(state_dict=converted_source)
|
||||
assert converter.compare_models((x,), threshold=1e-3)
|
||||
|
||||
return converted_source
|
||||
|
||||
|
||||
def convert_mask_decoder(mask_decoder: nn.Module) -> dict[str, Tensor]:
|
||||
manual_seed(seed=0)
|
||||
|
||||
refiners_mask_decoder = MaskDecoder()
|
||||
|
||||
image_embedding = torch.randn(1, 256, 64, 64)
|
||||
dense_positional_embedding = torch.randn(1, 256, 64, 64)
|
||||
point_embedding = torch.randn(1, 3, 256)
|
||||
mask_embedding = torch.randn(1, 256, 64, 64)
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from segment_anything.modeling.common import LayerNorm2d # type: ignore
|
||||
|
||||
assert issubclass(LayerNorm2d, nn.Module)
|
||||
custom_layers = {LayerNorm2d: fl.LayerNorm2d}
|
||||
|
||||
converter = ModelConverter(
|
||||
source_model=mask_decoder,
|
||||
target_model=refiners_mask_decoder,
|
||||
custom_layer_mapping=custom_layers, # type: ignore
|
||||
)
|
||||
|
||||
inputs = {
|
||||
"image_embeddings": image_embedding,
|
||||
"image_pe": dense_positional_embedding,
|
||||
"sparse_prompt_embeddings": point_embedding,
|
||||
"dense_prompt_embeddings": mask_embedding,
|
||||
"multimask_output": True,
|
||||
}
|
||||
|
||||
refiners_mask_decoder.set_image_embedding(image_embedding)
|
||||
refiners_mask_decoder.set_point_embedding(point_embedding)
|
||||
refiners_mask_decoder.set_mask_embedding(mask_embedding)
|
||||
refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding)
|
||||
|
||||
mapping = converter.map_state_dicts(source_args=inputs, target_args={})
|
||||
assert mapping is not None
|
||||
mapping["IOUMaskEncoder"] = "iou_token"
|
||||
|
||||
state_dict = converter._convert_state_dict(source_state_dict=mask_decoder.state_dict(), target_state_dict=refiners_mask_decoder.state_dict(), state_dict_mapping=mapping) # type: ignore
|
||||
state_dict["IOUMaskEncoder.weight"] = torch.cat(tensors=[mask_decoder.iou_token.weight, mask_decoder.mask_tokens.weight], dim=0) # type: ignore
|
||||
|
||||
refiners_mask_decoder.load_state_dict(state_dict=state_dict)
|
||||
|
||||
refiners_mask_decoder.set_image_embedding(image_embedding)
|
||||
refiners_mask_decoder.set_point_embedding(point_embedding)
|
||||
refiners_mask_decoder.set_mask_embedding(mask_embedding)
|
||||
refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding)
|
||||
|
||||
# Perform (1) upscaling then (2) mask prediction in this order (= like in the official implementation) to make
|
||||
# `compare_models` happy (MaskPrediction's Matmul runs those in the reverse order by default)
|
||||
matmul = refiners_mask_decoder.ensure_find(fl.Matmul)
|
||||
|
||||
def forward_swapped_order(self: Any, *args: Any) -> Any:
|
||||
y = self[1](*args) # (1)
|
||||
x = self[0](*args) # (2)
|
||||
return torch.matmul(input=x, other=y)
|
||||
|
||||
matmul.forward = types.MethodType(forward_swapped_order, matmul)
|
||||
|
||||
assert converter.compare_models(source_args=inputs, target_args={}, threshold=1e-3)
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Converts a Segment Anything ViT model to a Refiners SAMViTH model")
|
||||
parser.add_argument(
|
||||
"--from",
|
||||
type=str,
|
||||
dest="source_path",
|
||||
default="sam_vit_h_4b8939.pth",
|
||||
# required=True,
|
||||
help="Path to the Segment Anything model weights",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--to",
|
||||
type=str,
|
||||
dest="output_path",
|
||||
default="segment-anything-h.safetensors",
|
||||
help="Output path for converted model (as safetensors).",
|
||||
)
|
||||
parser.add_argument("--half", action="store_true", default=False, help="Convert to half precision. Default: False")
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Prints additional information during conversion. Default: False",
|
||||
)
|
||||
args = parser.parse_args(namespace=Args())
|
||||
|
||||
sam_h = build_sam_vit_h() # type: ignore
|
||||
sam_h.load_state_dict(state_dict=torch.load(f=args.source_path)) # type: ignore
|
||||
|
||||
vit_state_dict = convert_vit(vit=sam_h.image_encoder)
|
||||
mask_decoder_state_dict = convert_mask_decoder(mask_decoder=sam_h.mask_decoder)
|
||||
point_encoder_state_dict = convert_point_encoder(prompt_encoder=sam_h.prompt_encoder)
|
||||
mask_encoder_state_dict = convert_mask_encoder(prompt_encoder=sam_h.prompt_encoder)
|
||||
|
||||
output_state_dict = {
|
||||
**{".".join(("image_encoder", key)): value for key, value in vit_state_dict.items()},
|
||||
**{".".join(("mask_decoder", key)): value for key, value in mask_decoder_state_dict.items()},
|
||||
**{".".join(("point_encoder", key)): value for key, value in point_encoder_state_dict.items()},
|
||||
**{".".join(("mask_encoder", key)): value for key, value in mask_encoder_state_dict.items()},
|
||||
}
|
||||
if args.half:
|
||||
output_state_dict = {key: value.half() for key, value in output_state_dict.items()}
|
||||
|
||||
save_to_safetensors(path=args.output_path, tensors=output_state_dict)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
369
src/refiners/foundationals/segment_anything/image_encoder.py
Normal file
369
src/refiners/foundationals/segment_anything/image_encoder.py
Normal file
|
@ -0,0 +1,369 @@
|
|||
from torch import device as Device, dtype as DType, Tensor
|
||||
from refiners.fluxion.context import Contexts
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.utils import pad
|
||||
from torch import nn
|
||||
import torch
|
||||
|
||||
|
||||
class PatchEncoder(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
patch_size: int = 16,
|
||||
use_bias: bool = True,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.patch_size = patch_size
|
||||
self.use_bias = use_bias
|
||||
super().__init__(
|
||||
fl.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=(self.patch_size, self.patch_size),
|
||||
stride=(self.patch_size, self.patch_size),
|
||||
use_bias=self.use_bias,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
fl.Permute(0, 2, 3, 1),
|
||||
)
|
||||
|
||||
|
||||
class PositionalEncoder(fl.Residual):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
image_embedding_size: tuple[int, int],
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
self.embedding_dim = embedding_dim
|
||||
self.image_embedding_size = image_embedding_size
|
||||
super().__init__(
|
||||
fl.Parameter(
|
||||
1,
|
||||
image_embedding_size[0],
|
||||
image_embedding_size[1],
|
||||
embedding_dim,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class RelativePositionAttention(fl.WeightedModule):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
num_heads: int,
|
||||
spatial_size: tuple[int, int],
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = embedding_dim // num_heads
|
||||
self.spatial_size = spatial_size
|
||||
self.horizontal_embedding = nn.Parameter(
|
||||
data=torch.zeros(2 * spatial_size[0] - 1, self.head_dim, device=device, dtype=dtype)
|
||||
)
|
||||
self.vertical_embedding = nn.Parameter(
|
||||
data=torch.zeros(2 * spatial_size[1] - 1, self.head_dim, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
@property
|
||||
def device(self) -> Device:
|
||||
return self.horizontal_embedding.device
|
||||
|
||||
@property
|
||||
def dtype(self) -> DType:
|
||||
return self.horizontal_embedding.dtype
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
batch, height, width, _ = x.shape
|
||||
x = (
|
||||
x.reshape(batch, width * height, 3, self.num_heads, -1)
|
||||
.permute(2, 0, 3, 1, 4)
|
||||
.reshape(3, batch * self.num_heads, width * height, -1)
|
||||
)
|
||||
query, key, value = x.unbind(dim=0)
|
||||
horizontal_relative_embedding, vertical_relative_embedding = self.compute_relative_embedding(x=query)
|
||||
attention = (query * self.head_dim**-0.5) @ key.transpose(dim0=-2, dim1=-1)
|
||||
# Order of operations is important here
|
||||
attention = (
|
||||
(attention.reshape(-1, height, width, height, width) + vertical_relative_embedding)
|
||||
+ horizontal_relative_embedding
|
||||
).reshape(attention.shape)
|
||||
attention = attention.softmax(dim=-1)
|
||||
attention = attention @ value
|
||||
attention = (
|
||||
attention.reshape(batch, self.num_heads, height, width, -1)
|
||||
.permute(0, 2, 3, 1, 4)
|
||||
.reshape(batch, height, width, -1)
|
||||
)
|
||||
return attention
|
||||
|
||||
def compute_relative_coords(self, size: int) -> Tensor:
|
||||
x, y = torch.meshgrid(torch.arange(end=size), torch.arange(end=size), indexing="ij")
|
||||
return x - y + size - 1
|
||||
|
||||
def compute_relative_embedding(self, x: Tensor) -> tuple[Tensor, Tensor]:
|
||||
width, height = self.spatial_size
|
||||
horizontal_coords = self.compute_relative_coords(size=width)
|
||||
vertical_coords = self.compute_relative_coords(size=height)
|
||||
horizontal_positional_embedding = self.horizontal_embedding[horizontal_coords]
|
||||
vertical_positional_embedding = self.vertical_embedding[vertical_coords]
|
||||
x = x.reshape(x.shape[0], width, height, -1)
|
||||
horizontal_relative_embedding = torch.einsum("bhwc,wkc->bhwk", x, horizontal_positional_embedding).unsqueeze(
|
||||
dim=-2
|
||||
)
|
||||
vertical_relative_embedding = torch.einsum("bhwc,hkc->bhwk", x, vertical_positional_embedding).unsqueeze(dim=-1)
|
||||
|
||||
return horizontal_relative_embedding, vertical_relative_embedding
|
||||
|
||||
|
||||
class FusedSelfAttention(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 768,
|
||||
spatial_size: tuple[int, int] = (64, 64),
|
||||
num_heads: int = 1,
|
||||
use_bias: bool = True,
|
||||
is_causal: bool = False,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
assert (
|
||||
embedding_dim % num_heads == 0
|
||||
), f"Embedding dim (embedding_dim={embedding_dim}) must be divisible by num heads (num_heads={num_heads})"
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_heads = num_heads
|
||||
self.use_bias = use_bias
|
||||
self.is_causal = is_causal
|
||||
super().__init__(
|
||||
fl.Linear(
|
||||
in_features=self.embedding_dim,
|
||||
out_features=3 * self.embedding_dim,
|
||||
bias=self.use_bias,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
RelativePositionAttention(
|
||||
embedding_dim=self.embedding_dim,
|
||||
num_heads=self.num_heads,
|
||||
spatial_size=spatial_size,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
fl.Linear(
|
||||
in_features=self.embedding_dim,
|
||||
out_features=self.embedding_dim,
|
||||
bias=True,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class FeedForward(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
feedforward_dim: int,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
self.embedding_dim = embedding_dim
|
||||
self.feedforward_dim = feedforward_dim
|
||||
super().__init__(
|
||||
fl.Linear(
|
||||
in_features=self.embedding_dim,
|
||||
out_features=self.feedforward_dim,
|
||||
bias=True,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
fl.GeLU(),
|
||||
fl.Linear(
|
||||
in_features=self.feedforward_dim,
|
||||
out_features=self.embedding_dim,
|
||||
bias=True,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class WindowPartition(fl.ContextModule):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
batch, height, width, channels = x.shape
|
||||
context = self.use_context(context_name="window_partition")
|
||||
context.update({"original_height": height, "original_width": width})
|
||||
window_size = context["window_size"]
|
||||
padding_height = (window_size - height % window_size) % window_size
|
||||
padding_width = (window_size - width % window_size) % window_size
|
||||
if padding_height > 0 or padding_width > 0:
|
||||
x = pad(x=x, pad=(0, 0, 0, padding_width, 0, padding_height))
|
||||
padded_height, padded_width = height + padding_height, width + padding_width
|
||||
context.update({"padded_height": padded_height, "padded_width": padded_width})
|
||||
x = x.view(batch, padded_height // window_size, window_size, padded_width // window_size, window_size, channels)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, channels)
|
||||
return windows
|
||||
|
||||
|
||||
class WindowMerge(fl.ContextModule):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
context = self.use_context(context_name="window_partition")
|
||||
window_size = context["window_size"]
|
||||
padded_height, padded_width = context["padded_height"], context["padded_width"]
|
||||
original_height, original_width = context["original_height"], context["original_width"]
|
||||
batch_size = x.shape[0] // (padded_height * padded_width // window_size // window_size)
|
||||
x = x.view(batch_size, padded_height // window_size, padded_width // window_size, window_size, window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, padded_height, padded_width, -1)
|
||||
if padded_height > original_height or padded_width > original_width:
|
||||
x = x[:, :original_height, :original_width, :].contiguous()
|
||||
return x
|
||||
|
||||
|
||||
class TransformerLayer(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
num_heads: int,
|
||||
feedforward_dim: int,
|
||||
image_embedding_size: tuple[int, int],
|
||||
window_size: int | None = None,
|
||||
layer_norm_eps: float = 1e-6,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_heads = num_heads
|
||||
self.feedforward_dim = feedforward_dim
|
||||
self.window_size = window_size
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.image_embedding_size = image_embedding_size
|
||||
attention_spatial_size = (window_size, window_size) if window_size is not None else image_embedding_size
|
||||
reshape_or_merge = (
|
||||
WindowMerge()
|
||||
if self.window_size is not None
|
||||
else fl.Reshape(self.image_embedding_size[0], self.image_embedding_size[1], embedding_dim)
|
||||
)
|
||||
super().__init__(
|
||||
fl.Residual(
|
||||
fl.LayerNorm(normalized_shape=embedding_dim, eps=self.layer_norm_eps, device=device, dtype=dtype),
|
||||
WindowPartition() if self.window_size is not None else fl.Identity(),
|
||||
FusedSelfAttention(
|
||||
embedding_dim=embedding_dim,
|
||||
num_heads=num_heads,
|
||||
spatial_size=attention_spatial_size,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
reshape_or_merge,
|
||||
),
|
||||
fl.Residual(
|
||||
fl.LayerNorm(normalized_shape=embedding_dim, eps=self.layer_norm_eps, device=device, dtype=dtype),
|
||||
FeedForward(embedding_dim=embedding_dim, feedforward_dim=feedforward_dim, device=device, dtype=dtype),
|
||||
),
|
||||
)
|
||||
|
||||
def init_context(self) -> Contexts:
|
||||
return {"window_partition": {"window_size": self.window_size}}
|
||||
|
||||
|
||||
class Neck(fl.Chain):
|
||||
def __init__(self, in_channels: int = 768, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||
self.in_channels = in_channels
|
||||
super().__init__(
|
||||
fl.Permute(0, 3, 1, 2),
|
||||
fl.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=256,
|
||||
kernel_size=1,
|
||||
use_bias=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
fl.LayerNorm2d(channels=256, device=device, dtype=dtype),
|
||||
fl.Conv2d(
|
||||
in_channels=256,
|
||||
out_channels=256,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
use_bias=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
fl.LayerNorm2d(channels=256, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
|
||||
class Transformer(fl.Chain):
|
||||
pass
|
||||
|
||||
|
||||
class SAMViT(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
global_attention_indices: tuple[int, ...] | None = None,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_layers = num_layers
|
||||
self.num_heads = num_heads
|
||||
self.image_size = (1024, 1024)
|
||||
self.patch_size = 16
|
||||
self.window_size = 14
|
||||
self.image_embedding_size = (self.image_size[0] // self.patch_size, self.image_size[1] // self.patch_size)
|
||||
self.feed_forward_dim = 4 * self.embedding_dim
|
||||
self.global_attention_indices = global_attention_indices or tuple()
|
||||
super().__init__(
|
||||
PatchEncoder(
|
||||
in_channels=3, out_channels=embedding_dim, patch_size=self.patch_size, device=device, dtype=dtype
|
||||
),
|
||||
PositionalEncoder(
|
||||
embedding_dim=embedding_dim, image_embedding_size=self.image_embedding_size, device=device, dtype=dtype
|
||||
),
|
||||
Transformer(
|
||||
TransformerLayer(
|
||||
embedding_dim=embedding_dim,
|
||||
num_heads=num_heads,
|
||||
feedforward_dim=self.feed_forward_dim,
|
||||
window_size=self.window_size if i not in self.global_attention_indices else None,
|
||||
image_embedding_size=self.image_embedding_size,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
),
|
||||
Neck(in_channels=embedding_dim, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
|
||||
class SAMViTH(SAMViT):
|
||||
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||
super().__init__(
|
||||
embedding_dim=1280,
|
||||
num_layers=32,
|
||||
num_heads=16,
|
||||
global_attention_indices=(7, 15, 23, 31),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
264
src/refiners/foundationals/segment_anything/mask_decoder.py
Normal file
264
src/refiners/foundationals/segment_anything/mask_decoder.py
Normal file
|
@ -0,0 +1,264 @@
|
|||
import refiners.fluxion.layers as fl
|
||||
from torch import device as Device, dtype as DType, Tensor, nn
|
||||
import torch
|
||||
|
||||
from refiners.foundationals.segment_anything.transformer import (
|
||||
SparseCrossDenseAttention,
|
||||
TwoWayTranformerLayer,
|
||||
)
|
||||
from refiners.fluxion.context import Contexts
|
||||
|
||||
|
||||
class EmbeddingsAggregator(fl.ContextModule):
|
||||
def __init__(self, num_output_mask: int = 3) -> None:
|
||||
super().__init__()
|
||||
self.num_mask_tokens = num_output_mask
|
||||
|
||||
def forward(self, iou_mask_tokens: Tensor) -> Tensor:
|
||||
mask_decoder = self.ensure_parent
|
||||
mask_decoder_context = mask_decoder.use_context(context_name="mask_decoder")
|
||||
image_embedding = mask_decoder_context["image_embedding"]
|
||||
point_embedding = mask_decoder_context["point_embedding"]
|
||||
mask_embedding = mask_decoder_context["mask_embedding"]
|
||||
dense_positional_embedding = mask_decoder_context["dense_positional_embedding"]
|
||||
|
||||
sparse_embedding = torch.cat(tensors=(iou_mask_tokens, point_embedding), dim=1)
|
||||
dense_embedding = (image_embedding + mask_embedding).flatten(start_dim=2).transpose(1, 2)
|
||||
if dense_positional_embedding.shape != dense_embedding.shape:
|
||||
dense_positional_embedding = dense_positional_embedding.flatten(start_dim=2).transpose(1, 2)
|
||||
|
||||
mask_decoder_context.update(
|
||||
{
|
||||
"dense_embedding": dense_embedding,
|
||||
"dense_positional_embedding": dense_positional_embedding,
|
||||
"sparse_embedding": sparse_embedding,
|
||||
}
|
||||
)
|
||||
mask_decoder.set_context(context="mask_decoder", value=mask_decoder_context)
|
||||
|
||||
return sparse_embedding
|
||||
|
||||
|
||||
class Transformer(fl.Chain):
|
||||
pass
|
||||
|
||||
|
||||
class Hypernetworks(fl.Concatenate):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 256,
|
||||
num_layers: int = 3,
|
||||
num_mask_tokens: int = 3,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_layers = num_layers
|
||||
self.num_mask_tokens = num_mask_tokens
|
||||
|
||||
super().__init__(
|
||||
*[
|
||||
fl.Chain(
|
||||
fl.Slicing(dim=1, start=i + 1, length=1),
|
||||
fl.MultiLinear(
|
||||
input_dim=embedding_dim,
|
||||
output_dim=embedding_dim // 8,
|
||||
inner_dim=embedding_dim,
|
||||
num_layers=num_layers,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
)
|
||||
for i in range(num_mask_tokens + 1)
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
|
||||
class DenseEmbeddingUpscaling(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 256,
|
||||
dense_embedding_side_dim: int = 64,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.dense_embedding_side_dim = dense_embedding_side_dim
|
||||
|
||||
super().__init__(
|
||||
fl.UseContext(context="mask_decoder", key="dense_embedding"),
|
||||
fl.Transpose(dim0=1, dim1=2),
|
||||
fl.Reshape(embedding_dim, dense_embedding_side_dim, dense_embedding_side_dim),
|
||||
fl.ConvTranspose2d(
|
||||
in_channels=embedding_dim,
|
||||
out_channels=embedding_dim // 4,
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
fl.LayerNorm2d(channels=embedding_dim // 4, device=device, dtype=dtype),
|
||||
fl.GeLU(),
|
||||
fl.ConvTranspose2d(
|
||||
in_channels=embedding_dim // 4,
|
||||
out_channels=embedding_dim // 8,
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
fl.GeLU(),
|
||||
fl.Flatten(start_dim=2),
|
||||
)
|
||||
|
||||
|
||||
class IOUMaskEncoder(fl.WeightedModule):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 256,
|
||||
num_mask_tokens: int = 4,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_mask_tokens = num_mask_tokens
|
||||
# aka prompt tokens + output token (for IoU scores prediction)
|
||||
self.weight = nn.Parameter(data=torch.randn(num_mask_tokens + 1, embedding_dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self) -> Tensor:
|
||||
return self.weight.unsqueeze(dim=0)
|
||||
|
||||
|
||||
class MaskPrediction(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
num_mask_tokens: int,
|
||||
num_layers: int = 3,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_mask_tokens = num_mask_tokens
|
||||
self.num_layers = num_layers
|
||||
super().__init__(
|
||||
fl.Matmul(
|
||||
input=Hypernetworks(
|
||||
embedding_dim=embedding_dim,
|
||||
num_layers=num_layers,
|
||||
num_mask_tokens=num_mask_tokens,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
other=DenseEmbeddingUpscaling(embedding_dim=embedding_dim, device=device, dtype=dtype),
|
||||
),
|
||||
fl.Slicing(dim=1, start=1, length=num_mask_tokens),
|
||||
fl.Reshape(num_mask_tokens, embedding_dim, embedding_dim),
|
||||
)
|
||||
|
||||
|
||||
class IOUPrediction(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
num_layers: int,
|
||||
num_mask_tokens: int,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_layers = num_layers
|
||||
super().__init__(
|
||||
fl.Slicing(dim=1, start=0, length=1),
|
||||
fl.Squeeze(dim=0),
|
||||
fl.MultiLinear(
|
||||
input_dim=embedding_dim,
|
||||
output_dim=num_mask_tokens + 1,
|
||||
inner_dim=embedding_dim,
|
||||
num_layers=num_layers,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
fl.Slicing(dim=-1, start=1, length=num_mask_tokens),
|
||||
)
|
||||
|
||||
|
||||
class MaskDecoder(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 256,
|
||||
feed_forward_dim: int = 2048,
|
||||
num_layers: int = 2,
|
||||
num_output_mask: int = 3,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_mask_tokens = num_output_mask
|
||||
self.feed_forward_dim = feed_forward_dim
|
||||
self.num_layers = num_layers
|
||||
|
||||
super().__init__(
|
||||
IOUMaskEncoder(
|
||||
embedding_dim=embedding_dim, num_mask_tokens=num_output_mask + 1, device=device, dtype=dtype
|
||||
),
|
||||
EmbeddingsAggregator(num_output_mask=num_output_mask),
|
||||
Transformer(
|
||||
*(
|
||||
TwoWayTranformerLayer(
|
||||
embedding_dim=embedding_dim,
|
||||
num_heads=8,
|
||||
feed_forward_dim=feed_forward_dim,
|
||||
use_residual_self_attention=i > 0,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
),
|
||||
SparseCrossDenseAttention(embedding_dim=embedding_dim, device=device, dtype=dtype),
|
||||
fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
|
||||
),
|
||||
fl.Parallel(
|
||||
MaskPrediction(
|
||||
embedding_dim=embedding_dim, num_mask_tokens=num_output_mask, device=device, dtype=dtype
|
||||
),
|
||||
IOUPrediction(
|
||||
embedding_dim=embedding_dim,
|
||||
num_layers=3,
|
||||
num_mask_tokens=num_output_mask,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
def init_context(self) -> Contexts:
|
||||
return {
|
||||
"mask_decoder": {
|
||||
"image_embedding": None,
|
||||
"point_embedding": None,
|
||||
"mask_embedding": None,
|
||||
"dense_positional_embedding": None,
|
||||
}
|
||||
}
|
||||
|
||||
def set_image_embedding(self, image_embedding: Tensor) -> None:
|
||||
mask_decoder_context = self.use_context(context_name="mask_decoder")
|
||||
mask_decoder_context["image_embedding"] = image_embedding
|
||||
|
||||
def set_point_embedding(self, point_embedding: Tensor) -> None:
|
||||
mask_decoder_context = self.use_context(context_name="mask_decoder")
|
||||
mask_decoder_context["point_embedding"] = point_embedding
|
||||
|
||||
def set_mask_embedding(self, mask_embedding: Tensor) -> None:
|
||||
mask_decoder_context = self.use_context(context_name="mask_decoder")
|
||||
mask_decoder_context["mask_embedding"] = mask_embedding
|
||||
|
||||
def set_dense_positional_embedding(self, dense_positional_embedding: Tensor) -> None:
|
||||
mask_decoder_context = self.use_context(context_name="mask_decoder")
|
||||
mask_decoder_context["dense_positional_embedding"] = dense_positional_embedding
|
166
src/refiners/foundationals/segment_anything/model.py
Normal file
166
src/refiners/foundationals/segment_anything/model.py
Normal file
|
@ -0,0 +1,166 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Sequence
|
||||
from PIL import Image
|
||||
from torch import device as Device, dtype as DType, Tensor
|
||||
import numpy as np
|
||||
import torch
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.utils import image_to_tensor, normalize, pad, interpolate
|
||||
from refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH
|
||||
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
|
||||
from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageEmbedding:
|
||||
features: Tensor
|
||||
original_image_size: tuple[int, int] # (height, width)
|
||||
|
||||
|
||||
class SegmentAnything(fl.Module):
|
||||
mask_threshold: float = 0.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_encoder: SAMViT,
|
||||
point_encoder: PointEncoder,
|
||||
mask_encoder: MaskEncoder,
|
||||
mask_decoder: MaskDecoder,
|
||||
device: Device | str = "cpu",
|
||||
dtype: DType = torch.float32,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.device: Device = device if isinstance(device, Device) else Device(device=device)
|
||||
self.dtype = dtype
|
||||
self.image_encoder = image_encoder.to(device=self.device, dtype=self.dtype)
|
||||
self.point_encoder = point_encoder.to(device=self.device, dtype=self.dtype)
|
||||
self.mask_encoder = mask_encoder.to(device=self.device, dtype=self.dtype)
|
||||
self.mask_decoder = mask_decoder.to(device=self.device, dtype=self.dtype)
|
||||
|
||||
def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
|
||||
original_size = (image.height, image.width)
|
||||
target_size = self.compute_target_size(original_size)
|
||||
return ImageEmbedding(
|
||||
features=self.image_encoder(self.preprocess_image(image=image, target_size=target_size)),
|
||||
original_image_size=original_size,
|
||||
)
|
||||
|
||||
def predict(
|
||||
self,
|
||||
input: Image.Image | ImageEmbedding,
|
||||
foreground_points: Sequence[tuple[float, float]] | None = None,
|
||||
background_points: Sequence[tuple[float, float]] | None = None,
|
||||
box_points: Sequence[Sequence[tuple[float, float]]] | None = None,
|
||||
masks: Sequence[Image.Image] | None = None,
|
||||
binarize: bool = True,
|
||||
) -> tuple[Tensor, Tensor, Tensor]:
|
||||
if isinstance(input, ImageEmbedding):
|
||||
original_size = input.original_image_size
|
||||
target_size = self.compute_target_size(original_size)
|
||||
image_embedding = input.features
|
||||
else:
|
||||
original_size = (input.height, input.width)
|
||||
target_size = self.compute_target_size(original_size)
|
||||
image_embedding = self.image_encoder(self.preprocess_image(image=input, target_size=target_size))
|
||||
|
||||
coordinates, type_mask = self.point_encoder.points_to_tensor(
|
||||
foreground_points=foreground_points,
|
||||
background_points=background_points,
|
||||
box_points=box_points,
|
||||
)
|
||||
self.point_encoder.set_type_mask(type_mask=type_mask)
|
||||
|
||||
if masks is not None:
|
||||
mask_tensor = torch.stack(
|
||||
tensors=[image_to_tensor(image=mask, device=self.device, dtype=self.dtype) for mask in masks]
|
||||
)
|
||||
mask_embedding = self.mask_encoder(mask_tensor)
|
||||
else:
|
||||
mask_embedding = self.mask_encoder.get_no_mask_dense_embedding(
|
||||
image_embedding_size=self.image_encoder.image_embedding_size
|
||||
)
|
||||
point_embedding = self.point_encoder(
|
||||
self.normalize(coordinates, target_size=target_size, original_size=original_size)
|
||||
)
|
||||
dense_positional_embedding = self.point_encoder.get_dense_positional_embedding(
|
||||
image_embedding_size=self.image_encoder.image_embedding_size
|
||||
)
|
||||
|
||||
self.mask_decoder.set_image_embedding(image_embedding=image_embedding)
|
||||
self.mask_decoder.set_mask_embedding(mask_embedding=mask_embedding)
|
||||
self.mask_decoder.set_point_embedding(point_embedding=point_embedding)
|
||||
self.mask_decoder.set_dense_positional_embedding(dense_positional_embedding=dense_positional_embedding)
|
||||
|
||||
low_res_masks, iou_predictions = self.mask_decoder()
|
||||
|
||||
high_res_masks = self.postprocess_masks(
|
||||
masks=low_res_masks, target_size=target_size, original_size=original_size
|
||||
)
|
||||
|
||||
if binarize:
|
||||
high_res_masks = high_res_masks > self.mask_threshold
|
||||
|
||||
return high_res_masks, iou_predictions, low_res_masks
|
||||
|
||||
@property
|
||||
def image_size(self) -> int:
|
||||
w, h = self.image_encoder.image_size
|
||||
assert w == h
|
||||
return w
|
||||
|
||||
def compute_target_size(self, size: tuple[int, int]) -> tuple[int, int]:
|
||||
oldh, oldw = size
|
||||
scale = self.image_size * 1.0 / max(oldh, oldw)
|
||||
newh, neww = oldh * scale, oldw * scale
|
||||
neww = int(neww + 0.5)
|
||||
newh = int(newh + 0.5)
|
||||
return (newh, neww)
|
||||
|
||||
def preprocess_image(self, image: Image.Image, target_size: tuple[int, int]) -> Tensor:
|
||||
h, w = target_size
|
||||
padh = self.image_size - h
|
||||
padw = self.image_size - w
|
||||
image_tensor = torch.tensor(
|
||||
np.array(image.resize((w, h), resample=Image.Resampling.BILINEAR)).astype(np.float32).transpose(2, 0, 1),
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
).unsqueeze(0)
|
||||
return pad(
|
||||
normalize(image_tensor, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]), (0, padw, 0, padh)
|
||||
)
|
||||
|
||||
def normalize(self, coordinates: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
|
||||
coordinates[:, :, 0] = ((coordinates[:, :, 0] * (target_size[1] / original_size[1])) + 0.5) / self.image_size
|
||||
coordinates[:, :, 1] = ((coordinates[:, :, 1] * (target_size[0] / original_size[0])) + 0.5) / self.image_size
|
||||
return coordinates
|
||||
|
||||
def postprocess_masks(self, masks: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
|
||||
masks = interpolate(masks, factor=torch.Size((self.image_size, self.image_size)), mode="bilinear")
|
||||
masks = masks[..., : target_size[0], : target_size[1]] # remove padding added at `preprocess_image` time
|
||||
masks = interpolate(masks, factor=torch.Size(original_size), mode="bilinear")
|
||||
return masks
|
||||
|
||||
|
||||
class SegmentAnythingH(SegmentAnything):
|
||||
def __init__(
|
||||
self,
|
||||
image_encoder: SAMViTH | None = None,
|
||||
point_encoder: PointEncoder | None = None,
|
||||
mask_encoder: MaskEncoder | None = None,
|
||||
mask_decoder: MaskDecoder | None = None,
|
||||
device: Device | str = "cpu",
|
||||
dtype: DType = torch.float32,
|
||||
) -> None:
|
||||
image_encoder = image_encoder or SAMViTH()
|
||||
point_encoder = point_encoder or PointEncoder()
|
||||
mask_encoder = mask_encoder or MaskEncoder()
|
||||
mask_decoder = mask_decoder or MaskDecoder()
|
||||
|
||||
super().__init__(
|
||||
image_encoder=image_encoder,
|
||||
point_encoder=point_encoder,
|
||||
mask_encoder=mask_encoder,
|
||||
mask_decoder=mask_decoder,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
190
src/refiners/foundationals/segment_anything/prompt_encoder.py
Normal file
190
src/refiners/foundationals/segment_anything/prompt_encoder.py
Normal file
|
@ -0,0 +1,190 @@
|
|||
from enum import Enum, auto
|
||||
from collections.abc import Sequence
|
||||
from torch import device as Device, dtype as DType, Tensor, nn
|
||||
import torch
|
||||
from jaxtyping import Float, Int
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.context import Contexts
|
||||
|
||||
|
||||
class CoordinateEncoder(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
num_positional_features: int = 64,
|
||||
scale: float = 1,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
self.num_positional_features = num_positional_features
|
||||
self.scale = scale
|
||||
|
||||
super().__init__(
|
||||
fl.Multiply(scale=2, bias=-1),
|
||||
fl.Linear(in_features=2, out_features=num_positional_features, bias=False, device=device, dtype=dtype),
|
||||
fl.Multiply(scale=2 * torch.pi * self.scale),
|
||||
fl.Concatenate(fl.Sin(), fl.Cos(), dim=-1),
|
||||
)
|
||||
|
||||
|
||||
class PointType(Enum):
|
||||
BACKGROUND = auto()
|
||||
FOREGROUND = auto()
|
||||
BOX_TOP_LEFT = auto()
|
||||
BOX_BOTTOM_RIGHT = auto()
|
||||
NOT_A_POINT = auto()
|
||||
|
||||
|
||||
class PointTypeEmbedding(fl.WeightedModule, fl.ContextModule):
|
||||
def __init__(self, embedding_dim: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.weight = nn.Parameter(data=torch.randn(len(PointType), self.embedding_dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, type_mask: Int[Tensor, "1 num_points"]) -> Float[Tensor, "1 num_points embedding_dim"]:
|
||||
assert isinstance(type_mask, Tensor), "type_mask must be a Tensor."
|
||||
|
||||
embeddings = torch.zeros(*type_mask.shape, self.embedding_dim).to(device=type_mask.device)
|
||||
for type_id in PointType:
|
||||
mask = type_mask == type_id.value
|
||||
embeddings[mask] = self.weight[type_id.value - 1]
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
class PointEncoder(fl.Chain):
|
||||
def __init__(
|
||||
self, embedding_dim: int = 256, scale: float = 1, device: Device | str | None = None, dtype: DType | None = None
|
||||
) -> None:
|
||||
assert embedding_dim % 2 == 0, "embedding_dim must be divisible by 2."
|
||||
self.embedding_dim = embedding_dim
|
||||
self.scale = scale
|
||||
|
||||
super().__init__(
|
||||
CoordinateEncoder(num_positional_features=embedding_dim // 2, scale=scale, device=device, dtype=dtype),
|
||||
fl.Lambda(func=self.pad),
|
||||
fl.Residual(
|
||||
fl.UseContext(context="point_encoder", key="type_mask"),
|
||||
PointTypeEmbedding(embedding_dim=embedding_dim, device=device, dtype=dtype),
|
||||
),
|
||||
)
|
||||
|
||||
def pad(self, x: Tensor) -> Tensor:
|
||||
type_mask: Tensor = self.use_context("point_encoder")["type_mask"]
|
||||
if torch.any((type_mask == PointType.BOX_TOP_LEFT.value) | (type_mask == PointType.BOX_BOTTOM_RIGHT.value)):
|
||||
# Some boxes have been passed: no need to pad in this case
|
||||
return x
|
||||
type_mask = torch.cat(
|
||||
[type_mask, torch.full((type_mask.shape[0], 1), PointType.NOT_A_POINT.value, device=type_mask.device)],
|
||||
dim=1,
|
||||
)
|
||||
self.set_context(context="point_encoder", value={"type_mask": type_mask})
|
||||
return torch.cat([x, torch.zeros((x.shape[0], 1, x.shape[-1]), device=x.device)], dim=1)
|
||||
|
||||
def init_context(self) -> Contexts:
|
||||
return {
|
||||
"point_encoder": {
|
||||
"type_mask": None,
|
||||
}
|
||||
}
|
||||
|
||||
def set_type_mask(self, type_mask: Int[Tensor, "1 num_points"]) -> None:
|
||||
self.set_context(context="point_encoder", value={"type_mask": type_mask})
|
||||
|
||||
def get_dense_positional_embedding(
|
||||
self, image_embedding_size: tuple[int, int]
|
||||
) -> Float[Tensor, "num_positional_features height width"]:
|
||||
coordinate_encoder = self.ensure_find(layer_type=CoordinateEncoder)
|
||||
height, width = image_embedding_size
|
||||
grid = torch.ones((height, width), device=self.device, dtype=torch.float32)
|
||||
y_embedding = grid.cumsum(dim=0) - 0.5
|
||||
x_embedding = grid.cumsum(dim=1) - 0.5
|
||||
y_embedding = y_embedding / height
|
||||
x_embedding = x_embedding / width
|
||||
positional_embedding = (
|
||||
coordinate_encoder(torch.stack(tensors=[x_embedding, y_embedding], dim=-1))
|
||||
.permute(2, 0, 1)
|
||||
.unsqueeze(dim=0)
|
||||
)
|
||||
return positional_embedding
|
||||
|
||||
def points_to_tensor(
|
||||
self,
|
||||
foreground_points: Sequence[tuple[float, float]] | None = None,
|
||||
background_points: Sequence[tuple[float, float]] | None = None,
|
||||
not_a_points: Sequence[tuple[float, float]] | None = None,
|
||||
box_points: Sequence[Sequence[tuple[float, float]]] | None = None,
|
||||
) -> tuple[Float[Tensor, "1 num_points 2"], Int[Tensor, "1 num_points"]]:
|
||||
foreground_points = foreground_points or []
|
||||
background_points = background_points or []
|
||||
not_a_points = not_a_points or []
|
||||
box_points = box_points or []
|
||||
top_left_points = [box[0] for box in box_points]
|
||||
bottom_right_points = [box[1] for box in box_points]
|
||||
coordinates: list[Tensor] = []
|
||||
type_ids: list[Tensor] = []
|
||||
|
||||
# Must be in sync with PointType enum
|
||||
for type_id, coords_seq in zip(
|
||||
PointType, [background_points, foreground_points, top_left_points, bottom_right_points, not_a_points]
|
||||
):
|
||||
if len(coords_seq) > 0:
|
||||
coords_tensor = torch.tensor(data=list(coords_seq), dtype=torch.float, device=self.device)
|
||||
coordinates.append(coords_tensor)
|
||||
point_ids = torch.tensor(data=[type_id.value] * len(coords_seq), dtype=torch.int, device=self.device)
|
||||
type_ids.append(point_ids)
|
||||
|
||||
all_coordinates = torch.cat(tensors=coordinates, dim=0).unsqueeze(dim=0)
|
||||
type_mask = torch.cat(tensors=type_ids, dim=0).unsqueeze(dim=0)
|
||||
|
||||
return all_coordinates, type_mask
|
||||
|
||||
|
||||
class MaskEncoder(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 256,
|
||||
intermediate_channels: int = 16,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
self.embedding_dim = embedding_dim
|
||||
self.intermediate_channels = intermediate_channels
|
||||
super().__init__(
|
||||
fl.Conv2d(
|
||||
in_channels=1,
|
||||
out_channels=self.intermediate_channels // 4,
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
fl.LayerNorm2d(channels=self.intermediate_channels // 4, device=device, dtype=dtype),
|
||||
fl.GeLU(),
|
||||
fl.Conv2d(
|
||||
in_channels=self.intermediate_channels // 4,
|
||||
out_channels=self.intermediate_channels,
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
fl.LayerNorm2d(channels=self.intermediate_channels, device=device, dtype=dtype),
|
||||
fl.GeLU(),
|
||||
fl.Conv2d(
|
||||
in_channels=self.intermediate_channels,
|
||||
out_channels=self.embedding_dim,
|
||||
kernel_size=1,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
)
|
||||
self.register_parameter(
|
||||
"no_mask_embedding", nn.Parameter(torch.randn(1, embedding_dim, device=device, dtype=dtype))
|
||||
)
|
||||
|
||||
def get_no_mask_dense_embedding(
|
||||
self, image_embedding_size: tuple[int, int], batch_size: int = 1
|
||||
) -> Float[Tensor, "batch embedding_dim image_embedding_height image_embedding_width"]:
|
||||
return self.no_mask_embedding.reshape(1, -1, 1, 1).expand(
|
||||
batch_size, -1, image_embedding_size[0], image_embedding_size[1]
|
||||
)
|
157
src/refiners/foundationals/segment_anything/transformer.py
Normal file
157
src/refiners/foundationals/segment_anything/transformer.py
Normal file
|
@ -0,0 +1,157 @@
|
|||
from torch import dtype as DType, device as Device
|
||||
import refiners.fluxion.layers as fl
|
||||
|
||||
|
||||
class CrossAttention(fl.Attention):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
cross_embedding_dim: int | None = None,
|
||||
num_heads: int = 1,
|
||||
inner_dim: int | None = None,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
embedding_dim=embedding_dim,
|
||||
key_embedding_dim=cross_embedding_dim,
|
||||
num_heads=num_heads,
|
||||
inner_dim=inner_dim,
|
||||
is_optimized=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.cross_embedding_dim = cross_embedding_dim or embedding_dim
|
||||
self.insert(index=0, module=fl.Parallel(fl.GetArg(index=0), fl.GetArg(index=1), fl.GetArg(index=1)))
|
||||
|
||||
|
||||
class FeedForward(fl.Residual):
|
||||
def __init__(
|
||||
self, embedding_dim: int, feed_forward_dim: int, device: Device | str | None = None, dtype: DType | None = None
|
||||
) -> None:
|
||||
self.embedding_dim = embedding_dim
|
||||
self.feed_forward_dim = feed_forward_dim
|
||||
super().__init__(
|
||||
fl.Linear(in_features=embedding_dim, out_features=feed_forward_dim, device=device, dtype=dtype),
|
||||
fl.ReLU(),
|
||||
fl.Linear(in_features=feed_forward_dim, out_features=embedding_dim, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
|
||||
class SparseSelfAttention(fl.Residual):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
inner_dim: int | None = None,
|
||||
num_heads: int = 1,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
add_sparse_embedding = fl.Residual(fl.UseContext(context="mask_decoder", key="sparse_embedding"))
|
||||
super().__init__(
|
||||
fl.Parallel(add_sparse_embedding, add_sparse_embedding, fl.Identity()),
|
||||
fl.Attention(
|
||||
embedding_dim=embedding_dim,
|
||||
inner_dim=inner_dim,
|
||||
num_heads=num_heads,
|
||||
is_optimized=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SparseCrossDenseAttention(fl.Residual):
|
||||
def __init__(
|
||||
self, embedding_dim: int, num_heads: int = 8, device: Device | str | None = None, dtype: DType | None = None
|
||||
) -> None:
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_heads = num_heads
|
||||
super().__init__(
|
||||
fl.Parallel(
|
||||
fl.Residual(
|
||||
fl.UseContext(context="mask_decoder", key="sparse_embedding"),
|
||||
),
|
||||
fl.Sum(
|
||||
fl.UseContext(context="mask_decoder", key="dense_embedding"),
|
||||
fl.UseContext(context="mask_decoder", key="dense_positional_embedding"),
|
||||
),
|
||||
fl.UseContext(context="mask_decoder", key="dense_embedding"),
|
||||
),
|
||||
fl.Attention(
|
||||
embedding_dim=embedding_dim,
|
||||
inner_dim=embedding_dim // 2,
|
||||
num_heads=num_heads,
|
||||
is_optimized=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DenseCrossSparseAttention(fl.Chain):
|
||||
def __init__(
|
||||
self, embedding_dim: int, num_heads: int = 8, device: Device | str | None = None, dtype: DType | None = None
|
||||
) -> None:
|
||||
super().__init__(
|
||||
fl.Parallel(
|
||||
fl.Sum(
|
||||
fl.UseContext(context="mask_decoder", key="dense_embedding"),
|
||||
fl.UseContext(context="mask_decoder", key="dense_positional_embedding"),
|
||||
),
|
||||
fl.Residual(
|
||||
fl.UseContext(context="mask_decoder", key="sparse_embedding"),
|
||||
),
|
||||
fl.Identity(),
|
||||
),
|
||||
fl.Attention(
|
||||
embedding_dim=embedding_dim,
|
||||
inner_dim=embedding_dim // 2,
|
||||
num_heads=num_heads,
|
||||
is_optimized=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TwoWayTranformerLayer(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
num_heads: int = 8,
|
||||
feed_forward_dim: int = 2048,
|
||||
use_residual_self_attention: bool = True,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_heads = num_heads
|
||||
self.feed_forward_dim = feed_forward_dim
|
||||
|
||||
self_attention = (
|
||||
SparseSelfAttention(embedding_dim=embedding_dim, num_heads=num_heads, device=device, dtype=dtype)
|
||||
if use_residual_self_attention
|
||||
else fl.SelfAttention(
|
||||
embedding_dim=embedding_dim, num_heads=num_heads, is_optimized=False, device=device, dtype=dtype
|
||||
)
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
self_attention,
|
||||
fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
|
||||
SparseCrossDenseAttention(embedding_dim=embedding_dim, num_heads=num_heads, device=device, dtype=dtype),
|
||||
fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
|
||||
FeedForward(embedding_dim=embedding_dim, feed_forward_dim=feed_forward_dim, device=device, dtype=dtype),
|
||||
fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
|
||||
fl.Passthrough(
|
||||
fl.Sum(
|
||||
fl.UseContext(context="mask_decoder", key="dense_embedding"),
|
||||
DenseCrossSparseAttention(
|
||||
embedding_dim=embedding_dim, num_heads=num_heads, device=device, dtype=dtype
|
||||
),
|
||||
),
|
||||
fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
|
||||
fl.SetContext(context="mask_decoder", key="dense_embedding"),
|
||||
),
|
||||
)
|
|
@ -28,3 +28,8 @@ def test_e2e_path() -> Path:
|
|||
@fixture(scope="session")
|
||||
def test_textual_inversion_path() -> Path:
|
||||
return PARENT_PATH / "foundationals" / "clip" / "test_concepts_ref"
|
||||
|
||||
|
||||
@fixture(scope="session")
|
||||
def test_sam_path() -> Path:
|
||||
return PARENT_PATH / "foundationals" / "segment_anything"
|
||||
|
|
317
tests/foundationals/segment_anything/test_sam.py
Normal file
317
tests/foundationals/segment_anything/test_sam.py
Normal file
|
@ -0,0 +1,317 @@
|
|||
from math import isclose
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
from warnings import warn
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
from refiners.fluxion import manual_seed
|
||||
from refiners.fluxion.model_converter import ModelConverter
|
||||
|
||||
from refiners.fluxion.utils import image_to_tensor
|
||||
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention
|
||||
from refiners.foundationals.segment_anything.model import SegmentAnythingH
|
||||
from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer
|
||||
from tests.foundationals.segment_anything.utils import (
|
||||
FacebookSAM,
|
||||
FacebookSAMPredictor,
|
||||
SAMPrompt,
|
||||
intersection_over_union,
|
||||
)
|
||||
|
||||
# See predictor_example.ipynb official notebook (note: mask_input is not yet properly supported)
|
||||
PROMPTS: list[SAMPrompt] = [
|
||||
SAMPrompt(foreground_points=((500, 375),)),
|
||||
SAMPrompt(background_points=((500, 375),)),
|
||||
SAMPrompt(foreground_points=((500, 375), (1125, 625))),
|
||||
SAMPrompt(foreground_points=((500, 375),), background_points=((1125, 625),)),
|
||||
SAMPrompt(box_points=[[(425, 600), (700, 875)]]),
|
||||
SAMPrompt(box_points=[[(425, 600), (700, 875)]], background_points=((575, 750),)),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(params=PROMPTS)
|
||||
def prompt(request: pytest.FixtureRequest) -> SAMPrompt:
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def one_prompt() -> SAMPrompt:
|
||||
return PROMPTS[0]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def facebook_sam_h_weights(test_weights_path: Path) -> Path:
|
||||
sam_h_weights = test_weights_path / "sam_vit_h_4b8939.pth"
|
||||
if not sam_h_weights.is_file():
|
||||
warn(f"could not find weights at {sam_h_weights}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
return sam_h_weights
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def sam_h_weights(test_weights_path: Path) -> Path:
|
||||
sam_h_weights = test_weights_path / "segment-anything-h.safetensors"
|
||||
if not sam_h_weights.is_file():
|
||||
warn(f"could not find weights at {sam_h_weights}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
return sam_h_weights
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def facebook_sam_h(facebook_sam_h_weights: Path, test_device: torch.device) -> FacebookSAM:
|
||||
from segment_anything import build_sam_vit_h # type: ignore
|
||||
|
||||
sam_h = cast(FacebookSAM, build_sam_vit_h())
|
||||
sam_h.load_state_dict(state_dict=torch.load(f=facebook_sam_h_weights)) # type: ignore
|
||||
return sam_h.to(device=test_device)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def facebook_sam_h_predictor(facebook_sam_h: FacebookSAM) -> FacebookSAMPredictor:
|
||||
from segment_anything import SamPredictor # type: ignore
|
||||
from segment_anything.modeling import Sam # type: ignore
|
||||
|
||||
predictor = SamPredictor(cast(Sam, facebook_sam_h))
|
||||
return cast(FacebookSAMPredictor, predictor)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
|
||||
sam_h = SegmentAnythingH(device=test_device)
|
||||
# TODO: make strict=True when the MasKEncoder conversion is done
|
||||
sam_h.load_from_safetensors(tensors_path=sam_h_weights, strict=False)
|
||||
return sam_h
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ref_path(test_sam_path: Path) -> Path:
|
||||
return test_sam_path / "test_sam_ref"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def truck(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "truck.jpg").convert("RGB")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None:
|
||||
manual_seed(seed=0)
|
||||
x = torch.randn(25, 14, 14, 1280, device=facebook_sam_h.device)
|
||||
|
||||
attention = cast(nn.Module, facebook_sam_h.image_encoder.blocks[0].attn) # type: ignore
|
||||
|
||||
refiners_attention = FusedSelfAttention(
|
||||
embedding_dim=1280, num_heads=16, spatial_size=(14, 14), device=facebook_sam_h.device
|
||||
)
|
||||
refiners_attention.Linear_1.weight = attention.qkv.weight # type: ignore
|
||||
refiners_attention.Linear_1.bias = attention.qkv.bias # type: ignore
|
||||
refiners_attention.Linear_2.weight = attention.proj.weight # type: ignore
|
||||
refiners_attention.Linear_2.bias = attention.proj.bias # type: ignore
|
||||
refiners_attention.RelativePositionAttention.horizontal_embedding = attention.rel_pos_w
|
||||
refiners_attention.RelativePositionAttention.vertical_embedding = attention.rel_pos_h
|
||||
|
||||
y_1 = attention(x)
|
||||
assert y_1.shape == x.shape
|
||||
|
||||
y_2 = refiners_attention(x)
|
||||
assert y_2.shape == x.shape
|
||||
|
||||
assert torch.equal(input=y_1, other=y_2)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, truck: Image.Image) -> None:
|
||||
image_tensor = image_to_tensor(image=truck.resize(size=(1024, 1024)), device=facebook_sam_h.device)
|
||||
y_1 = facebook_sam_h.image_encoder(image_tensor)
|
||||
y_2 = sam_h.image_encoder(image_tensor)
|
||||
|
||||
assert torch.equal(input=y_1, other=y_2)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_prompt_encoder_dense_positional_embedding(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
|
||||
facebook_prompt_encoder = facebook_sam_h.prompt_encoder
|
||||
refiners_prompt_encoder = sam_h.point_encoder
|
||||
|
||||
facebook_dense_pe: Tensor = cast(Tensor, facebook_prompt_encoder.get_dense_pe()) # type: ignore
|
||||
refiners_dense_pe = refiners_prompt_encoder.get_dense_positional_embedding(image_embedding_size=(64, 64))
|
||||
|
||||
assert torch.equal(input=refiners_dense_pe, other=facebook_dense_pe)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_prompt_encoder_no_mask_dense_embedding(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
|
||||
facebook_prompt_encoder = facebook_sam_h.prompt_encoder
|
||||
refiners_prompt_encoder = sam_h.mask_encoder
|
||||
|
||||
_, facebook_dense_pe = facebook_prompt_encoder(points=None, boxes=None, masks=None)
|
||||
refiners_dense_pe = refiners_prompt_encoder.get_no_mask_dense_embedding(image_embedding_size=(64, 64))
|
||||
|
||||
assert torch.equal(input=refiners_dense_pe, other=facebook_dense_pe)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_point_encoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH, prompt: SAMPrompt) -> None:
|
||||
facebook_prompt_encoder = facebook_sam_h.prompt_encoder
|
||||
refiners_prompt_encoder = sam_h.point_encoder
|
||||
|
||||
facebook_sparse_pe, _ = facebook_prompt_encoder(
|
||||
**prompt.facebook_prompt_encoder_kwargs(device=facebook_sam_h.device)
|
||||
)
|
||||
|
||||
coordinates, type_mask = refiners_prompt_encoder.points_to_tensor(**prompt.__dict__)
|
||||
# Shift to center of pixel + normalize in [0, 1] (see `_embed_points` in segment-anything official repo)
|
||||
coordinates[:, :, 0] = (coordinates[:, :, 0] + 0.5) / 1024.0
|
||||
coordinates[:, :, 1] = (coordinates[:, :, 1] + 0.5) / 1024.0
|
||||
refiners_prompt_encoder.set_type_mask(type_mask=type_mask)
|
||||
refiners_sparse_pe = refiners_prompt_encoder(coordinates)
|
||||
|
||||
assert torch.equal(input=refiners_sparse_pe, other=facebook_sparse_pe)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_two_way_transformer(facebook_sam_h: FacebookSAM) -> None:
|
||||
dense_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device)
|
||||
dense_positional_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device)
|
||||
sparse_embedding = torch.randn(1, 3, 256, device=facebook_sam_h.device)
|
||||
|
||||
refiners_layer = TwoWayTranformerLayer(
|
||||
embedding_dim=256, feed_forward_dim=2048, num_heads=8, device=facebook_sam_h.device
|
||||
)
|
||||
facebook_layer = facebook_sam_h.mask_decoder.transformer.layers[1] # type: ignore
|
||||
assert isinstance(facebook_layer, nn.Module)
|
||||
|
||||
refiners_layer.set_context(
|
||||
context="mask_decoder",
|
||||
value={
|
||||
"dense_embedding": dense_embedding,
|
||||
"dense_positional_embedding": dense_positional_embedding,
|
||||
"sparse_embedding": sparse_embedding,
|
||||
},
|
||||
)
|
||||
facebook_inputs = {
|
||||
"queries": sparse_embedding,
|
||||
"keys": dense_embedding,
|
||||
"query_pe": sparse_embedding,
|
||||
"key_pe": dense_positional_embedding,
|
||||
}
|
||||
|
||||
converter = ModelConverter(
|
||||
source_model=facebook_layer,
|
||||
target_model=refiners_layer,
|
||||
skip_output_check=True, # done below, manually
|
||||
)
|
||||
|
||||
assert converter.run(source_args=facebook_inputs, target_args=(sparse_embedding,))
|
||||
|
||||
refiners_layer.set_context(
|
||||
context="mask_decoder",
|
||||
value={
|
||||
"dense_embedding": dense_embedding,
|
||||
"dense_positional_embedding": dense_positional_embedding,
|
||||
"sparse_embedding": sparse_embedding,
|
||||
},
|
||||
)
|
||||
y_1 = facebook_layer(**facebook_inputs)[0]
|
||||
y_2 = refiners_layer(sparse_embedding)[0]
|
||||
|
||||
assert torch.equal(input=y_1, other=y_2)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_mask_decoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
|
||||
manual_seed(seed=0)
|
||||
facebook_mask_decoder = facebook_sam_h.mask_decoder
|
||||
refiners_mask_decoder = sam_h.mask_decoder
|
||||
|
||||
image_embedding = torch.randn(1, 256, 64, 64, device=facebook_sam_h.device)
|
||||
dense_positional_embedding = torch.randn(1, 256, 64, 64, device=facebook_sam_h.device)
|
||||
point_embedding = torch.randn(1, 3, 256, device=facebook_sam_h.device)
|
||||
mask_embedding = torch.randn(1, 256, 64, 64, device=facebook_sam_h.device)
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from segment_anything.modeling.common import LayerNorm2d # type: ignore
|
||||
|
||||
assert issubclass(LayerNorm2d, nn.Module)
|
||||
custom_layers = {LayerNorm2d: fl.LayerNorm2d}
|
||||
|
||||
converter = ModelConverter(
|
||||
source_model=facebook_mask_decoder,
|
||||
target_model=refiners_mask_decoder,
|
||||
custom_layer_mapping=custom_layers, # type: ignore
|
||||
)
|
||||
|
||||
inputs = {
|
||||
"image_embeddings": image_embedding,
|
||||
"image_pe": dense_positional_embedding,
|
||||
"sparse_prompt_embeddings": point_embedding,
|
||||
"dense_prompt_embeddings": mask_embedding,
|
||||
"multimask_output": True,
|
||||
}
|
||||
|
||||
refiners_mask_decoder.set_image_embedding(image_embedding)
|
||||
refiners_mask_decoder.set_point_embedding(point_embedding)
|
||||
refiners_mask_decoder.set_mask_embedding(mask_embedding)
|
||||
refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding)
|
||||
|
||||
mapping = converter.map_state_dicts(source_args=inputs, target_args={})
|
||||
assert mapping is not None
|
||||
mapping["IOUMaskEncoder"] = "iou_token"
|
||||
|
||||
state_dict = converter._convert_state_dict(source_state_dict=facebook_mask_decoder.state_dict(), target_state_dict=refiners_mask_decoder.state_dict(), state_dict_mapping=mapping) # type: ignore
|
||||
state_dict["IOUMaskEncoder.weight"] = torch.cat([facebook_mask_decoder.iou_token.weight, facebook_mask_decoder.mask_tokens.weight], dim=0) # type: ignore
|
||||
refiners_mask_decoder.load_state_dict(state_dict=state_dict)
|
||||
|
||||
facebook_output = facebook_mask_decoder(**inputs)
|
||||
|
||||
refiners_mask_decoder.set_image_embedding(image_embedding)
|
||||
refiners_mask_decoder.set_point_embedding(point_embedding)
|
||||
refiners_mask_decoder.set_mask_embedding(mask_embedding)
|
||||
refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding)
|
||||
mask_prediction, iou_prediction = refiners_mask_decoder()
|
||||
|
||||
facebook_masks = facebook_output[0]
|
||||
facebook_prediction = facebook_output[1]
|
||||
|
||||
assert torch.equal(input=mask_prediction, other=facebook_masks)
|
||||
assert torch.equal(input=iou_prediction, other=facebook_prediction)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_predictor(
|
||||
facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, prompt: SAMPrompt
|
||||
) -> None:
|
||||
predictor = facebook_sam_h_predictor
|
||||
predictor.set_image(np.array(truck))
|
||||
facebook_masks, facebook_scores, _ = predictor.predict(**prompt.facebook_predict_kwargs()) # type: ignore
|
||||
|
||||
assert len(facebook_masks) == 3
|
||||
|
||||
masks, scores, _ = sam_h.predict(truck, **prompt.__dict__)
|
||||
masks = masks.squeeze(0)
|
||||
scores = scores.squeeze(0)
|
||||
|
||||
assert len(masks) == 3
|
||||
|
||||
for i in range(3):
|
||||
mask_prediction = masks[i].cpu()
|
||||
facebook_mask = torch.as_tensor(facebook_masks[i])
|
||||
assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05)
|
||||
assert isclose(scores[i].item(), facebook_scores[i].item(), rel_tol=1e-05)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_predictor_image_embedding(sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt) -> None:
|
||||
masks_ref, scores_ref, _ = sam_h.predict(truck, **one_prompt.__dict__)
|
||||
|
||||
image_embedding = sam_h.compute_image_embedding(truck)
|
||||
masks, scores, _ = sam_h.predict(image_embedding, **one_prompt.__dict__)
|
||||
|
||||
assert torch.equal(masks, masks_ref)
|
||||
assert torch.equal(scores_ref, scores)
|
|
@ -0,0 +1,3 @@
|
|||
# Note about this data
|
||||
|
||||
`truck.jpg` is one of the [images](https://github.com/facebookresearch/segment-anything/tree/main/notebooks/images) used in the official [segment-anything notebooks](https://github.com/facebookresearch/segment-anything/tree/main/notebooks).
|
BIN
tests/foundationals/segment_anything/test_sam_ref/truck.jpg
Normal file
BIN
tests/foundationals/segment_anything/test_sam_ref/truck.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 265 KiB |
107
tests/foundationals/segment_anything/utils.py
Normal file
107
tests/foundationals/segment_anything/utils.py
Normal file
|
@ -0,0 +1,107 @@
|
|||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from jaxtyping import Bool
|
||||
from torch import Tensor, nn
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
|
||||
NDArrayUInt8 = npt.NDArray[np.uint8]
|
||||
NDArray = npt.NDArray[Any]
|
||||
|
||||
|
||||
class SAMInput(TypedDict):
|
||||
image: Tensor
|
||||
original_size: tuple[int, int]
|
||||
point_coords: Tensor | None
|
||||
point_labels: Tensor | None
|
||||
boxes: Tensor | None
|
||||
mask_inputs: Tensor | None
|
||||
|
||||
|
||||
class SAMOutput(TypedDict):
|
||||
masks: Tensor
|
||||
iou_predictions: Tensor
|
||||
low_res_logits: Tensor
|
||||
|
||||
|
||||
class FacebookSAM(nn.Module):
|
||||
image_encoder: nn.Module
|
||||
prompt_encoder: nn.Module
|
||||
mask_decoder: nn.Module
|
||||
|
||||
def __call__(self, batched_input: list[SAMInput], multimask_output: bool) -> list[SAMOutput]: ...
|
||||
|
||||
@property
|
||||
def device(self) -> Any: ...
|
||||
|
||||
|
||||
class FacebookSAMPredictor:
|
||||
model: FacebookSAM
|
||||
|
||||
def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ...
|
||||
|
||||
def predict(
|
||||
self,
|
||||
point_coords: NDArray | None = None,
|
||||
point_labels: NDArray | None = None,
|
||||
box: NDArray | None = None,
|
||||
mask_input: NDArray | None = None,
|
||||
multimask_output: bool = True,
|
||||
return_logits: bool = False,
|
||||
) -> tuple[NDArray, NDArray, NDArray]: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class SAMPrompt:
|
||||
foreground_points: Sequence[tuple[float, float]] | None = None
|
||||
background_points: Sequence[tuple[float, float]] | None = None
|
||||
box_points: Sequence[Sequence[tuple[float, float]]] | None = None
|
||||
# TODO: support masks
|
||||
# masks: Sequence[Image.Image] | None = None
|
||||
|
||||
def facebook_predict_kwargs(self) -> dict[str, NDArray]:
|
||||
prompt: dict[str, NDArray] = {}
|
||||
# Note: the order matters since `points_to_tensor` processes points that way (background -> foreground -> etc)
|
||||
if self.background_points:
|
||||
prompt["point_coords"] = np.array(self.background_points)
|
||||
prompt["point_labels"] = np.array([0] * len(self.background_points))
|
||||
if self.foreground_points:
|
||||
coords = np.array(self.foreground_points)
|
||||
prompt["point_coords"] = (
|
||||
coords if "point_coords" not in prompt else np.concatenate((prompt["point_coords"], coords))
|
||||
)
|
||||
labels = np.array([1] * len(self.foreground_points))
|
||||
prompt["point_labels"] = (
|
||||
labels if "point_labels" not in prompt else np.concatenate((prompt["point_labels"], labels))
|
||||
)
|
||||
if self.box_points:
|
||||
prompt["box"] = np.array([coord for batch in self.box_points for xy in batch for coord in xy]).reshape(
|
||||
len(self.box_points), 4
|
||||
)
|
||||
return prompt
|
||||
|
||||
def facebook_prompt_encoder_kwargs(self, device: torch.device | None = None):
|
||||
prompt = self.facebook_predict_kwargs()
|
||||
coords: Tensor | None = None
|
||||
labels: Tensor | None = None
|
||||
boxes: Tensor | None = None
|
||||
if "point_coords" in prompt:
|
||||
coords = torch.as_tensor(prompt["point_coords"], dtype=torch.float, device=device).unsqueeze(0)
|
||||
if "point_labels" in prompt:
|
||||
labels = torch.as_tensor(prompt["point_labels"], dtype=torch.int, device=device).unsqueeze(0)
|
||||
if "box" in prompt:
|
||||
boxes = torch.as_tensor(prompt["box"], dtype=torch.float, device=device).unsqueeze(0)
|
||||
points = (coords, labels) if coords is not None else None
|
||||
# TODO: support masks
|
||||
return {"points": points, "boxes": boxes, "masks": None}
|
||||
|
||||
|
||||
def intersection_over_union(
|
||||
input_mask: Bool[Tensor, "height width"], other_mask: Bool[Tensor, "height width"]
|
||||
) -> float:
|
||||
inter = (input_mask & other_mask).sum(dtype=torch.float32).item()
|
||||
union = (input_mask | other_mask).sum(dtype=torch.float32).item()
|
||||
return inter / union if union > 0 else 1.0
|
Loading…
Reference in a new issue