diff --git a/poetry.lock b/poetry.lock index 599aac9..6baa469 100644 --- a/poetry.lock +++ b/poetry.lock @@ -651,21 +651,19 @@ tests = ["asttokens", "littleutils", "pytest", "rich"] [[package]] name = "filelock" -version = "3.12.3" +version = "3.12.4" description = "A platform independent file lock." optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.12.3-py3-none-any.whl", hash = "sha256:f067e40ccc40f2b48395a80fcbd4728262fab54e232e090a4063ab804179efeb"}, - {file = "filelock-3.12.3.tar.gz", hash = "sha256:0ecc1dd2ec4672a10c8550a8182f1bd0c0a5088470ecd5a125e45f49472fac3d"}, + {file = "filelock-3.12.4-py3-none-any.whl", hash = "sha256:08c21d87ded6e2b9da6728c3dff51baf1dcecf973b768ef35bcbc3447edb9ad4"}, + {file = "filelock-3.12.4.tar.gz", hash = "sha256:2e6f249f1f3654291606e046b09f1fd5eac39b360664c27f5aad072012f8bcbd"}, ] -[package.dependencies] -typing-extensions = {version = ">=4.7.1", markers = "python_version < \"3.11\""} - [package.extras] docs = ["furo (>=2023.7.26)", "sphinx (>=7.1.2)", "sphinx-autodoc-typehints (>=1.24)"] testing = ["covdefaults (>=2.3)", "coverage (>=7.3)", "diff-cover (>=7.7)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)", "pytest-timeout (>=2.1)"] +typing = ["typing-extensions (>=4.7.1)"] [[package]] name = "frozenlist" @@ -2507,6 +2505,25 @@ dev = ["click", "cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyl doc = ["jupytext", "matplotlib (>2)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"] test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] +[[package]] +name = "segment_anything" +version = "1.0" +description = "" +optional = true +python-versions = "*" +files = [] +develop = false + +[package.extras] +all = ["matplotlib", "onnx", "onnxruntime", "opencv-python", "pycocotools"] +dev = ["black", "flake8", "isort", "mypy"] + +[package.source] +type = "git" +url = "https://github.com/facebookresearch/segment-anything" +reference = "HEAD" +resolved_reference = "6fdee8f2727f4506cfbbe553e23b895e27956588" + [[package]] name = "sentry-sdk" version = "1.31.0" @@ -3289,11 +3306,11 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"] [extras] -conversion = ["diffusers", "transformers"] -test = ["diffusers", "invisible-watermark", "piq", "transformers"] +conversion = ["diffusers", "segment-anything", "transformers"] +test = ["diffusers", "invisible-watermark", "piq", "segment-anything", "transformers"] training = ["bitsandbytes", "datasets", "loguru", "prodigyopt", "pydantic", "scipy", "tomli", "torchvision", "wandb"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.12" -content-hash = "b2646e04015b05eba32e6556b4aa37d3c5c31218c09fd0ea53d7540af558b7de" +content-hash = "6e35f01f2fc8611203da972939170d18371182443e4c6f45b4cfc40b4e785dff" diff --git a/pyproject.toml b/pyproject.toml index 7c4e083..4c7bde5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,11 +29,12 @@ diffusers = {version = "^0.18.0", optional = true} transformers = {version = "^4.27.4", optional = true} piq = {version = "^0.7.1", optional = true} invisible-watermark = {version = "^0.2.0", optional = true} 
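+# Note: segment-anything is installed straight from the official GitHub repository (pinned via poetry.lock's git source entry).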
+segment-anything = {git = "https://github.com/facebookresearch/segment-anything", optional = true} [tool.poetry.extras] training = ["datasets", "tomli", "wandb", "loguru", "bitsandbytes", "prodigyopt", "pydantic", "scipy", "torchvision"] -conversion = ["diffusers", "transformers"] -test = ["diffusers", "transformers", "piq", "invisible-watermark"] +conversion = ["diffusers", "transformers", "segment-anything"] +test = ["diffusers", "transformers", "piq", "invisible-watermark", "segment-anything"] [tool.poetry.group.dev.dependencies] black = "^23.1.0" diff --git a/scripts/conversion/convert_segment_anything.py b/scripts/conversion/convert_segment_anything.py new file mode 100644 index 0000000..fdfd9bd --- /dev/null +++ b/scripts/conversion/convert_segment_anything.py @@ -0,0 +1,236 @@ +import argparse +import types +from typing import Any, Callable, cast +import torch +import torch.nn as nn +from torch import Tensor + +import refiners.fluxion.layers as fl +from refiners.fluxion.model_converter import ModelConverter +from refiners.fluxion.utils import manual_seed, save_to_safetensors +from refiners.foundationals.segment_anything.image_encoder import SAMViTH +from refiners.foundationals.segment_anything.prompt_encoder import PointEncoder, MaskEncoder + +from segment_anything import build_sam_vit_h # type: ignore +from segment_anything.modeling.common import LayerNorm2d # type: ignore + +from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder + + +class FacebookSAM(nn.Module): + image_encoder: nn.Module + prompt_encoder: nn.Module + mask_decoder: nn.Module + + +build_sam_vit_h = cast(Callable[[], FacebookSAM], build_sam_vit_h) + + +assert issubclass(LayerNorm2d, nn.Module) +custom_layers = {LayerNorm2d: fl.LayerNorm2d} + + +class Args(argparse.Namespace): + source_path: str + output_path: str + half: bool + verbose: bool + + +def convert_mask_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]: + state_dict: dict[str, Tensor] = { + "no_mask_embedding": nn.Parameter(data=prompt_encoder.no_mask_embed.weight.clone()), # type: ignore + } + + refiners_mask_encoder = MaskEncoder() + # TODO: handle other weights + refiners_mask_encoder.load_state_dict(state_dict=state_dict, strict=False) + + return state_dict + + +def convert_point_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]: + manual_seed(seed=0) + point_embeddings: list[Tensor] = [pe.weight for pe in prompt_encoder.point_embeddings] + [prompt_encoder.not_a_point_embed.weight] # type: ignore + pe = prompt_encoder.pe_layer.positional_encoding_gaussian_matrix # type: ignore + assert isinstance(pe, Tensor) + state_dict: dict[str, Tensor] = { + "Residual.Chain.PointTypeEmbedding.weight": nn.Parameter(data=torch.cat(tensors=point_embeddings, dim=0)), + "CoordinateEncoder.Linear.weight": nn.Parameter(data=pe.T.contiguous()), + } + + refiners_prompt_encoder = PointEncoder() + refiners_prompt_encoder.load_state_dict(state_dict=state_dict) + + return state_dict + + +def convert_vit(vit: nn.Module) -> dict[str, Tensor]: + manual_seed(seed=0) + refiners_sam_vit_h = SAMViTH() + + converter = ModelConverter( + source_model=vit, + target_model=refiners_sam_vit_h, + custom_layer_mapping=custom_layers, # type: ignore + ) + converter.skip_init_check = True + + x = torch.randn(1, 3, 1024, 1024) + mapping = converter.map_state_dicts(source_args=(x,)) + assert mapping + + mapping["PositionalEncoder.Chain.Parameter.parameter"] = "pos_embed" + + target_state_dict = refiners_sam_vit_h.state_dict() + del 
target_state_dict["PositionalEncoder.Chain.Parameter.parameter"] + + source_state_dict = vit.state_dict() + pos_embed = source_state_dict["pos_embed"] + del source_state_dict["pos_embed"] + + target_rel_keys = [ + ( + f"Transformer.TransformerLayer_{i}.Residual_1.Chain.FusedSelfAttention.RelativePositionAttention.horizontal_embedding", + f"Transformer.TransformerLayer_{i}.Residual_1.Chain.FusedSelfAttention.RelativePositionAttention.vertical_embedding", + ) + for i in range(1, 33) + ] + source_rel_keys = [(f"blocks.{i}.attn.rel_pos_w", f"blocks.{i}.attn.rel_pos_h") for i in range(32)] + + rel_items: dict[str, Tensor] = {} + + for (key_w, key_h), (target_key_w, target_key_h) in zip(source_rel_keys, target_rel_keys): + rel_items[target_key_w] = source_state_dict[key_w] + rel_items[target_key_h] = source_state_dict[key_h] + del source_state_dict[key_w] + del source_state_dict[key_h] + del target_state_dict[target_key_w] + del target_state_dict[target_key_h] + + converted_source = converter._convert_state_dict( # pyright: ignore[reportPrivateUsage] + source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping + ) + + converted_source["PositionalEncoder.Chain.Parameter.parameter"] = pos_embed # type: ignore + converted_source.update(rel_items) + + refiners_sam_vit_h.load_state_dict(state_dict=converted_source) + assert converter.compare_models((x,), threshold=1e-3) + + return converted_source + + +def convert_mask_decoder(mask_decoder: nn.Module) -> dict[str, Tensor]: + manual_seed(seed=0) + + refiners_mask_decoder = MaskDecoder() + + image_embedding = torch.randn(1, 256, 64, 64) + dense_positional_embedding = torch.randn(1, 256, 64, 64) + point_embedding = torch.randn(1, 3, 256) + mask_embedding = torch.randn(1, 256, 64, 64) + + import refiners.fluxion.layers as fl + from segment_anything.modeling.common import LayerNorm2d # type: ignore + + assert issubclass(LayerNorm2d, nn.Module) + custom_layers = {LayerNorm2d: fl.LayerNorm2d} + + converter = ModelConverter( + source_model=mask_decoder, + target_model=refiners_mask_decoder, + custom_layer_mapping=custom_layers, # type: ignore + ) + + inputs = { + "image_embeddings": image_embedding, + "image_pe": dense_positional_embedding, + "sparse_prompt_embeddings": point_embedding, + "dense_prompt_embeddings": mask_embedding, + "multimask_output": True, + } + + refiners_mask_decoder.set_image_embedding(image_embedding) + refiners_mask_decoder.set_point_embedding(point_embedding) + refiners_mask_decoder.set_mask_embedding(mask_embedding) + refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding) + + mapping = converter.map_state_dicts(source_args=inputs, target_args={}) + assert mapping is not None + mapping["IOUMaskEncoder"] = "iou_token" + + state_dict = converter._convert_state_dict(source_state_dict=mask_decoder.state_dict(), target_state_dict=refiners_mask_decoder.state_dict(), state_dict_mapping=mapping) # type: ignore + state_dict["IOUMaskEncoder.weight"] = torch.cat(tensors=[mask_decoder.iou_token.weight, mask_decoder.mask_tokens.weight], dim=0) # type: ignore + + refiners_mask_decoder.load_state_dict(state_dict=state_dict) + + refiners_mask_decoder.set_image_embedding(image_embedding) + refiners_mask_decoder.set_point_embedding(point_embedding) + refiners_mask_decoder.set_mask_embedding(mask_embedding) + refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding) + + # Perform (1) upscaling then (2) mask prediction in this order (= like in the official 
implementation) to make + # `compare_models` happy (MaskPrediction's Matmul runs those in the reverse order by default) + matmul = refiners_mask_decoder.ensure_find(fl.Matmul) + + def forward_swapped_order(self: Any, *args: Any) -> Any: + y = self[1](*args) # (1) + x = self[0](*args) # (2) + return torch.matmul(input=x, other=y) + + matmul.forward = types.MethodType(forward_swapped_order, matmul) + + assert converter.compare_models(source_args=inputs, target_args={}, threshold=1e-3) + + return state_dict + + +def main() -> None: + parser = argparse.ArgumentParser(description="Converts a Segment Anything ViT model to a Refiners SAMViTH model") + parser.add_argument( + "--from", + type=str, + dest="source_path", + default="sam_vit_h_4b8939.pth", + # required=True, + help="Path to the Segment Anything model weights", + ) + parser.add_argument( + "--to", + type=str, + dest="output_path", + default="segment-anything-h.safetensors", + help="Output path for converted model (as safetensors).", + ) + parser.add_argument("--half", action="store_true", default=False, help="Convert to half precision. Default: False") + parser.add_argument( + "--verbose", + action="store_true", + default=False, + help="Prints additional information during conversion. Default: False", + ) + args = parser.parse_args(namespace=Args()) + + sam_h = build_sam_vit_h() # type: ignore + sam_h.load_state_dict(state_dict=torch.load(f=args.source_path)) # type: ignore + + vit_state_dict = convert_vit(vit=sam_h.image_encoder) + mask_decoder_state_dict = convert_mask_decoder(mask_decoder=sam_h.mask_decoder) + point_encoder_state_dict = convert_point_encoder(prompt_encoder=sam_h.prompt_encoder) + mask_encoder_state_dict = convert_mask_encoder(prompt_encoder=sam_h.prompt_encoder) + + output_state_dict = { + **{".".join(("image_encoder", key)): value for key, value in vit_state_dict.items()}, + **{".".join(("mask_decoder", key)): value for key, value in mask_decoder_state_dict.items()}, + **{".".join(("point_encoder", key)): value for key, value in point_encoder_state_dict.items()}, + **{".".join(("mask_encoder", key)): value for key, value in mask_encoder_state_dict.items()}, + } + if args.half: + output_state_dict = {key: value.half() for key, value in output_state_dict.items()} + + save_to_safetensors(path=args.output_path, tensors=output_state_dict) + + +if __name__ == "__main__": + main() diff --git a/src/refiners/foundationals/segment_anything/__init__.py b/src/refiners/foundationals/segment_anything/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/refiners/foundationals/segment_anything/image_encoder.py b/src/refiners/foundationals/segment_anything/image_encoder.py new file mode 100644 index 0000000..4a4f7e7 --- /dev/null +++ b/src/refiners/foundationals/segment_anything/image_encoder.py @@ -0,0 +1,369 @@ +from torch import device as Device, dtype as DType, Tensor +from refiners.fluxion.context import Contexts +import refiners.fluxion.layers as fl +from refiners.fluxion.utils import pad +from torch import nn +import torch + + +class PatchEncoder(fl.Chain): + def __init__( + self, + in_channels: int, + out_channels: int, + patch_size: int = 16, + use_bias: bool = True, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + self.in_channels = in_channels + self.out_channels = out_channels + self.patch_size = patch_size + self.use_bias = use_bias + super().__init__( + fl.Conv2d( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=(self.patch_size, 
self.patch_size), + stride=(self.patch_size, self.patch_size), + use_bias=self.use_bias, + device=device, + dtype=dtype, + ), + fl.Permute(0, 2, 3, 1), + ) + + +class PositionalEncoder(fl.Residual): + def __init__( + self, + embedding_dim: int, + image_embedding_size: tuple[int, int], + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + self.embedding_dim = embedding_dim + self.image_embedding_size = image_embedding_size + super().__init__( + fl.Parameter( + 1, + image_embedding_size[0], + image_embedding_size[1], + embedding_dim, + device=device, + dtype=dtype, + ), + ) + + +class RelativePositionAttention(fl.WeightedModule): + def __init__( + self, + embedding_dim: int, + num_heads: int, + spatial_size: tuple[int, int], + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.head_dim = embedding_dim // num_heads + self.spatial_size = spatial_size + self.horizontal_embedding = nn.Parameter( + data=torch.zeros(2 * spatial_size[0] - 1, self.head_dim, device=device, dtype=dtype) + ) + self.vertical_embedding = nn.Parameter( + data=torch.zeros(2 * spatial_size[1] - 1, self.head_dim, device=device, dtype=dtype) + ) + + @property + def device(self) -> Device: + return self.horizontal_embedding.device + + @property + def dtype(self) -> DType: + return self.horizontal_embedding.dtype + + def forward(self, x: Tensor) -> Tensor: + batch, height, width, _ = x.shape + x = ( + x.reshape(batch, width * height, 3, self.num_heads, -1) + .permute(2, 0, 3, 1, 4) + .reshape(3, batch * self.num_heads, width * height, -1) + ) + query, key, value = x.unbind(dim=0) + horizontal_relative_embedding, vertical_relative_embedding = self.compute_relative_embedding(x=query) + attention = (query * self.head_dim**-0.5) @ key.transpose(dim0=-2, dim1=-1) + # Order of operations is important here + attention = ( + (attention.reshape(-1, height, width, height, width) + vertical_relative_embedding) + + horizontal_relative_embedding + ).reshape(attention.shape) + attention = attention.softmax(dim=-1) + attention = attention @ value + attention = ( + attention.reshape(batch, self.num_heads, height, width, -1) + .permute(0, 2, 3, 1, 4) + .reshape(batch, height, width, -1) + ) + return attention + + def compute_relative_coords(self, size: int) -> Tensor: + x, y = torch.meshgrid(torch.arange(end=size), torch.arange(end=size), indexing="ij") + return x - y + size - 1 + + def compute_relative_embedding(self, x: Tensor) -> tuple[Tensor, Tensor]: + width, height = self.spatial_size + horizontal_coords = self.compute_relative_coords(size=width) + vertical_coords = self.compute_relative_coords(size=height) + horizontal_positional_embedding = self.horizontal_embedding[horizontal_coords] + vertical_positional_embedding = self.vertical_embedding[vertical_coords] + x = x.reshape(x.shape[0], width, height, -1) + horizontal_relative_embedding = torch.einsum("bhwc,wkc->bhwk", x, horizontal_positional_embedding).unsqueeze( + dim=-2 + ) + vertical_relative_embedding = torch.einsum("bhwc,hkc->bhwk", x, vertical_positional_embedding).unsqueeze(dim=-1) + + return horizontal_relative_embedding, vertical_relative_embedding + + +class FusedSelfAttention(fl.Chain): + def __init__( + self, + embedding_dim: int = 768, + spatial_size: tuple[int, int] = (64, 64), + num_heads: int = 1, + use_bias: bool = True, + is_causal: bool = False, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> 
None: + assert ( + embedding_dim % num_heads == 0 + ), f"Embedding dim (embedding_dim={embedding_dim}) must be divisible by num heads (num_heads={num_heads})" + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.use_bias = use_bias + self.is_causal = is_causal + super().__init__( + fl.Linear( + in_features=self.embedding_dim, + out_features=3 * self.embedding_dim, + bias=self.use_bias, + device=device, + dtype=dtype, + ), + RelativePositionAttention( + embedding_dim=self.embedding_dim, + num_heads=self.num_heads, + spatial_size=spatial_size, + device=device, + dtype=dtype, + ), + fl.Linear( + in_features=self.embedding_dim, + out_features=self.embedding_dim, + bias=True, + device=device, + dtype=dtype, + ), + ) + + +class FeedForward(fl.Chain): + def __init__( + self, + embedding_dim: int, + feedforward_dim: int, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + self.embedding_dim = embedding_dim + self.feedforward_dim = feedforward_dim + super().__init__( + fl.Linear( + in_features=self.embedding_dim, + out_features=self.feedforward_dim, + bias=True, + device=device, + dtype=dtype, + ), + fl.GeLU(), + fl.Linear( + in_features=self.feedforward_dim, + out_features=self.embedding_dim, + bias=True, + device=device, + dtype=dtype, + ), + ) + + +class WindowPartition(fl.ContextModule): + def __init__(self) -> None: + super().__init__() + + def forward(self, x: Tensor) -> Tensor: + batch, height, width, channels = x.shape + context = self.use_context(context_name="window_partition") + context.update({"original_height": height, "original_width": width}) + window_size = context["window_size"] + padding_height = (window_size - height % window_size) % window_size + padding_width = (window_size - width % window_size) % window_size + if padding_height > 0 or padding_width > 0: + x = pad(x=x, pad=(0, 0, 0, padding_width, 0, padding_height)) + padded_height, padded_width = height + padding_height, width + padding_width + context.update({"padded_height": padded_height, "padded_width": padded_width}) + x = x.view(batch, padded_height // window_size, window_size, padded_width // window_size, window_size, channels) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, channels) + return windows + + +class WindowMerge(fl.ContextModule): + def __init__(self) -> None: + super().__init__() + + def forward(self, x: Tensor) -> Tensor: + context = self.use_context(context_name="window_partition") + window_size = context["window_size"] + padded_height, padded_width = context["padded_height"], context["padded_width"] + original_height, original_width = context["original_height"], context["original_width"] + batch_size = x.shape[0] // (padded_height * padded_width // window_size // window_size) + x = x.view(batch_size, padded_height // window_size, padded_width // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, padded_height, padded_width, -1) + if padded_height > original_height or padded_width > original_width: + x = x[:, :original_height, :original_width, :].contiguous() + return x + + +class TransformerLayer(fl.Chain): + def __init__( + self, + embedding_dim: int, + num_heads: int, + feedforward_dim: int, + image_embedding_size: tuple[int, int], + window_size: int | None = None, + layer_norm_eps: float = 1e-6, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + self.embedding_dim = embedding_dim + self.num_heads = num_heads + 
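+        # A window_size of None makes this layer attend globally over the full image_embedding_size grid (see attention_spatial_size below).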
self.feedforward_dim = feedforward_dim + self.window_size = window_size + self.layer_norm_eps = layer_norm_eps + self.image_embedding_size = image_embedding_size + attention_spatial_size = (window_size, window_size) if window_size is not None else image_embedding_size + reshape_or_merge = ( + WindowMerge() + if self.window_size is not None + else fl.Reshape(self.image_embedding_size[0], self.image_embedding_size[1], embedding_dim) + ) + super().__init__( + fl.Residual( + fl.LayerNorm(normalized_shape=embedding_dim, eps=self.layer_norm_eps, device=device, dtype=dtype), + WindowPartition() if self.window_size is not None else fl.Identity(), + FusedSelfAttention( + embedding_dim=embedding_dim, + num_heads=num_heads, + spatial_size=attention_spatial_size, + device=device, + dtype=dtype, + ), + reshape_or_merge, + ), + fl.Residual( + fl.LayerNorm(normalized_shape=embedding_dim, eps=self.layer_norm_eps, device=device, dtype=dtype), + FeedForward(embedding_dim=embedding_dim, feedforward_dim=feedforward_dim, device=device, dtype=dtype), + ), + ) + + def init_context(self) -> Contexts: + return {"window_partition": {"window_size": self.window_size}} + + +class Neck(fl.Chain): + def __init__(self, in_channels: int = 768, device: Device | str | None = None, dtype: DType | None = None) -> None: + self.in_channels = in_channels + super().__init__( + fl.Permute(0, 3, 1, 2), + fl.Conv2d( + in_channels=self.in_channels, + out_channels=256, + kernel_size=1, + use_bias=False, + device=device, + dtype=dtype, + ), + fl.LayerNorm2d(channels=256, device=device, dtype=dtype), + fl.Conv2d( + in_channels=256, + out_channels=256, + kernel_size=3, + padding=1, + use_bias=False, + device=device, + dtype=dtype, + ), + fl.LayerNorm2d(channels=256, device=device, dtype=dtype), + ) + + +class Transformer(fl.Chain): + pass + + +class SAMViT(fl.Chain): + def __init__( + self, + embedding_dim: int, + num_layers: int, + num_heads: int, + global_attention_indices: tuple[int, ...] 
| None = None, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + self.embedding_dim = embedding_dim + self.num_layers = num_layers + self.num_heads = num_heads + self.image_size = (1024, 1024) + self.patch_size = 16 + self.window_size = 14 + self.image_embedding_size = (self.image_size[0] // self.patch_size, self.image_size[1] // self.patch_size) + self.feed_forward_dim = 4 * self.embedding_dim + self.global_attention_indices = global_attention_indices or tuple() + super().__init__( + PatchEncoder( + in_channels=3, out_channels=embedding_dim, patch_size=self.patch_size, device=device, dtype=dtype + ), + PositionalEncoder( + embedding_dim=embedding_dim, image_embedding_size=self.image_embedding_size, device=device, dtype=dtype + ), + Transformer( + TransformerLayer( + embedding_dim=embedding_dim, + num_heads=num_heads, + feedforward_dim=self.feed_forward_dim, + window_size=self.window_size if i not in self.global_attention_indices else None, + image_embedding_size=self.image_embedding_size, + device=device, + dtype=dtype, + ) + for i in range(num_layers) + ), + Neck(in_channels=embedding_dim, device=device, dtype=dtype), + ) + + +class SAMViTH(SAMViT): + def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None: + super().__init__( + embedding_dim=1280, + num_layers=32, + num_heads=16, + global_attention_indices=(7, 15, 23, 31), + device=device, + dtype=dtype, + ) diff --git a/src/refiners/foundationals/segment_anything/mask_decoder.py b/src/refiners/foundationals/segment_anything/mask_decoder.py new file mode 100644 index 0000000..5502bca --- /dev/null +++ b/src/refiners/foundationals/segment_anything/mask_decoder.py @@ -0,0 +1,264 @@ +import refiners.fluxion.layers as fl +from torch import device as Device, dtype as DType, Tensor, nn +import torch + +from refiners.foundationals.segment_anything.transformer import ( + SparseCrossDenseAttention, + TwoWayTranformerLayer, +) +from refiners.fluxion.context import Contexts + + +class EmbeddingsAggregator(fl.ContextModule): + def __init__(self, num_output_mask: int = 3) -> None: + super().__init__() + self.num_mask_tokens = num_output_mask + + def forward(self, iou_mask_tokens: Tensor) -> Tensor: + mask_decoder = self.ensure_parent + mask_decoder_context = mask_decoder.use_context(context_name="mask_decoder") + image_embedding = mask_decoder_context["image_embedding"] + point_embedding = mask_decoder_context["point_embedding"] + mask_embedding = mask_decoder_context["mask_embedding"] + dense_positional_embedding = mask_decoder_context["dense_positional_embedding"] + + sparse_embedding = torch.cat(tensors=(iou_mask_tokens, point_embedding), dim=1) + dense_embedding = (image_embedding + mask_embedding).flatten(start_dim=2).transpose(1, 2) + if dense_positional_embedding.shape != dense_embedding.shape: + dense_positional_embedding = dense_positional_embedding.flatten(start_dim=2).transpose(1, 2) + + mask_decoder_context.update( + { + "dense_embedding": dense_embedding, + "dense_positional_embedding": dense_positional_embedding, + "sparse_embedding": sparse_embedding, + } + ) + mask_decoder.set_context(context="mask_decoder", value=mask_decoder_context) + + return sparse_embedding + + +class Transformer(fl.Chain): + pass + + +class Hypernetworks(fl.Concatenate): + def __init__( + self, + embedding_dim: int = 256, + num_layers: int = 3, + num_mask_tokens: int = 3, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + super().__init__() + self.embedding_dim 
= embedding_dim + self.num_layers = num_layers + self.num_mask_tokens = num_mask_tokens + + super().__init__( + *[ + fl.Chain( + fl.Slicing(dim=1, start=i + 1, length=1), + fl.MultiLinear( + input_dim=embedding_dim, + output_dim=embedding_dim // 8, + inner_dim=embedding_dim, + num_layers=num_layers, + device=device, + dtype=dtype, + ), + ) + for i in range(num_mask_tokens + 1) + ], + dim=1, + ) + + +class DenseEmbeddingUpscaling(fl.Chain): + def __init__( + self, + embedding_dim: int = 256, + dense_embedding_side_dim: int = 64, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.dense_embedding_side_dim = dense_embedding_side_dim + + super().__init__( + fl.UseContext(context="mask_decoder", key="dense_embedding"), + fl.Transpose(dim0=1, dim1=2), + fl.Reshape(embedding_dim, dense_embedding_side_dim, dense_embedding_side_dim), + fl.ConvTranspose2d( + in_channels=embedding_dim, + out_channels=embedding_dim // 4, + kernel_size=2, + stride=2, + device=device, + dtype=dtype, + ), + fl.LayerNorm2d(channels=embedding_dim // 4, device=device, dtype=dtype), + fl.GeLU(), + fl.ConvTranspose2d( + in_channels=embedding_dim // 4, + out_channels=embedding_dim // 8, + kernel_size=2, + stride=2, + device=device, + dtype=dtype, + ), + fl.GeLU(), + fl.Flatten(start_dim=2), + ) + + +class IOUMaskEncoder(fl.WeightedModule): + def __init__( + self, + embedding_dim: int = 256, + num_mask_tokens: int = 4, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.num_mask_tokens = num_mask_tokens + # aka prompt tokens + output token (for IoU scores prediction) + self.weight = nn.Parameter(data=torch.randn(num_mask_tokens + 1, embedding_dim, device=device, dtype=dtype)) + + def forward(self) -> Tensor: + return self.weight.unsqueeze(dim=0) + + +class MaskPrediction(fl.Chain): + def __init__( + self, + embedding_dim: int, + num_mask_tokens: int, + num_layers: int = 3, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + self.embedding_dim = embedding_dim + self.num_mask_tokens = num_mask_tokens + self.num_layers = num_layers + super().__init__( + fl.Matmul( + input=Hypernetworks( + embedding_dim=embedding_dim, + num_layers=num_layers, + num_mask_tokens=num_mask_tokens, + device=device, + dtype=dtype, + ), + other=DenseEmbeddingUpscaling(embedding_dim=embedding_dim, device=device, dtype=dtype), + ), + fl.Slicing(dim=1, start=1, length=num_mask_tokens), + fl.Reshape(num_mask_tokens, embedding_dim, embedding_dim), + ) + + +class IOUPrediction(fl.Chain): + def __init__( + self, + embedding_dim: int, + num_layers: int, + num_mask_tokens: int, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + self.embedding_dim = embedding_dim + self.num_layers = num_layers + super().__init__( + fl.Slicing(dim=1, start=0, length=1), + fl.Squeeze(dim=0), + fl.MultiLinear( + input_dim=embedding_dim, + output_dim=num_mask_tokens + 1, + inner_dim=embedding_dim, + num_layers=num_layers, + device=device, + dtype=dtype, + ), + fl.Slicing(dim=-1, start=1, length=num_mask_tokens), + ) + + +class MaskDecoder(fl.Chain): + def __init__( + self, + embedding_dim: int = 256, + feed_forward_dim: int = 2048, + num_layers: int = 2, + num_output_mask: int = 3, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + 
self.num_mask_tokens = num_output_mask + self.feed_forward_dim = feed_forward_dim + self.num_layers = num_layers + + super().__init__( + IOUMaskEncoder( + embedding_dim=embedding_dim, num_mask_tokens=num_output_mask + 1, device=device, dtype=dtype + ), + EmbeddingsAggregator(num_output_mask=num_output_mask), + Transformer( + *( + TwoWayTranformerLayer( + embedding_dim=embedding_dim, + num_heads=8, + feed_forward_dim=feed_forward_dim, + use_residual_self_attention=i > 0, + device=device, + dtype=dtype, + ) + for i in range(num_layers) + ), + SparseCrossDenseAttention(embedding_dim=embedding_dim, device=device, dtype=dtype), + fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype), + ), + fl.Parallel( + MaskPrediction( + embedding_dim=embedding_dim, num_mask_tokens=num_output_mask, device=device, dtype=dtype + ), + IOUPrediction( + embedding_dim=embedding_dim, + num_layers=3, + num_mask_tokens=num_output_mask, + device=device, + dtype=dtype, + ), + ), + ) + + def init_context(self) -> Contexts: + return { + "mask_decoder": { + "image_embedding": None, + "point_embedding": None, + "mask_embedding": None, + "dense_positional_embedding": None, + } + } + + def set_image_embedding(self, image_embedding: Tensor) -> None: + mask_decoder_context = self.use_context(context_name="mask_decoder") + mask_decoder_context["image_embedding"] = image_embedding + + def set_point_embedding(self, point_embedding: Tensor) -> None: + mask_decoder_context = self.use_context(context_name="mask_decoder") + mask_decoder_context["point_embedding"] = point_embedding + + def set_mask_embedding(self, mask_embedding: Tensor) -> None: + mask_decoder_context = self.use_context(context_name="mask_decoder") + mask_decoder_context["mask_embedding"] = mask_embedding + + def set_dense_positional_embedding(self, dense_positional_embedding: Tensor) -> None: + mask_decoder_context = self.use_context(context_name="mask_decoder") + mask_decoder_context["dense_positional_embedding"] = dense_positional_embedding diff --git a/src/refiners/foundationals/segment_anything/model.py b/src/refiners/foundationals/segment_anything/model.py new file mode 100644 index 0000000..1c841c2 --- /dev/null +++ b/src/refiners/foundationals/segment_anything/model.py @@ -0,0 +1,166 @@ +from dataclasses import dataclass +from typing import Sequence +from PIL import Image +from torch import device as Device, dtype as DType, Tensor +import numpy as np +import torch +import refiners.fluxion.layers as fl +from refiners.fluxion.utils import image_to_tensor, normalize, pad, interpolate +from refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH +from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder +from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder + + +@dataclass +class ImageEmbedding: + features: Tensor + original_image_size: tuple[int, int] # (height, width) + + +class SegmentAnything(fl.Module): + mask_threshold: float = 0.0 + + def __init__( + self, + image_encoder: SAMViT, + point_encoder: PointEncoder, + mask_encoder: MaskEncoder, + mask_decoder: MaskDecoder, + device: Device | str = "cpu", + dtype: DType = torch.float32, + ) -> None: + super().__init__() + self.device: Device = device if isinstance(device, Device) else Device(device=device) + self.dtype = dtype + self.image_encoder = image_encoder.to(device=self.device, dtype=self.dtype) + self.point_encoder = point_encoder.to(device=self.device, dtype=self.dtype) + self.mask_encoder = 
mask_encoder.to(device=self.device, dtype=self.dtype) + self.mask_decoder = mask_decoder.to(device=self.device, dtype=self.dtype) + + def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding: + original_size = (image.height, image.width) + target_size = self.compute_target_size(original_size) + return ImageEmbedding( + features=self.image_encoder(self.preprocess_image(image=image, target_size=target_size)), + original_image_size=original_size, + ) + + def predict( + self, + input: Image.Image | ImageEmbedding, + foreground_points: Sequence[tuple[float, float]] | None = None, + background_points: Sequence[tuple[float, float]] | None = None, + box_points: Sequence[Sequence[tuple[float, float]]] | None = None, + masks: Sequence[Image.Image] | None = None, + binarize: bool = True, + ) -> tuple[Tensor, Tensor, Tensor]: + if isinstance(input, ImageEmbedding): + original_size = input.original_image_size + target_size = self.compute_target_size(original_size) + image_embedding = input.features + else: + original_size = (input.height, input.width) + target_size = self.compute_target_size(original_size) + image_embedding = self.image_encoder(self.preprocess_image(image=input, target_size=target_size)) + + coordinates, type_mask = self.point_encoder.points_to_tensor( + foreground_points=foreground_points, + background_points=background_points, + box_points=box_points, + ) + self.point_encoder.set_type_mask(type_mask=type_mask) + + if masks is not None: + mask_tensor = torch.stack( + tensors=[image_to_tensor(image=mask, device=self.device, dtype=self.dtype) for mask in masks] + ) + mask_embedding = self.mask_encoder(mask_tensor) + else: + mask_embedding = self.mask_encoder.get_no_mask_dense_embedding( + image_embedding_size=self.image_encoder.image_embedding_size + ) + point_embedding = self.point_encoder( + self.normalize(coordinates, target_size=target_size, original_size=original_size) + ) + dense_positional_embedding = self.point_encoder.get_dense_positional_embedding( + image_embedding_size=self.image_encoder.image_embedding_size + ) + + self.mask_decoder.set_image_embedding(image_embedding=image_embedding) + self.mask_decoder.set_mask_embedding(mask_embedding=mask_embedding) + self.mask_decoder.set_point_embedding(point_embedding=point_embedding) + self.mask_decoder.set_dense_positional_embedding(dense_positional_embedding=dense_positional_embedding) + + low_res_masks, iou_predictions = self.mask_decoder() + + high_res_masks = self.postprocess_masks( + masks=low_res_masks, target_size=target_size, original_size=original_size + ) + + if binarize: + high_res_masks = high_res_masks > self.mask_threshold + + return high_res_masks, iou_predictions, low_res_masks + + @property + def image_size(self) -> int: + w, h = self.image_encoder.image_size + assert w == h + return w + + def compute_target_size(self, size: tuple[int, int]) -> tuple[int, int]: + oldh, oldw = size + scale = self.image_size * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) + + def preprocess_image(self, image: Image.Image, target_size: tuple[int, int]) -> Tensor: + h, w = target_size + padh = self.image_size - h + padw = self.image_size - w + image_tensor = torch.tensor( + np.array(image.resize((w, h), resample=Image.Resampling.BILINEAR)).astype(np.float32).transpose(2, 0, 1), + device=self.device, + dtype=self.dtype, + ).unsqueeze(0) + return pad( + normalize(image_tensor, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 
57.375]), (0, padw, 0, padh) + ) + + def normalize(self, coordinates: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor: + coordinates[:, :, 0] = ((coordinates[:, :, 0] * (target_size[1] / original_size[1])) + 0.5) / self.image_size + coordinates[:, :, 1] = ((coordinates[:, :, 1] * (target_size[0] / original_size[0])) + 0.5) / self.image_size + return coordinates + + def postprocess_masks(self, masks: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor: + masks = interpolate(masks, factor=torch.Size((self.image_size, self.image_size)), mode="bilinear") + masks = masks[..., : target_size[0], : target_size[1]] # remove padding added at `preprocess_image` time + masks = interpolate(masks, factor=torch.Size(original_size), mode="bilinear") + return masks + + +class SegmentAnythingH(SegmentAnything): + def __init__( + self, + image_encoder: SAMViTH | None = None, + point_encoder: PointEncoder | None = None, + mask_encoder: MaskEncoder | None = None, + mask_decoder: MaskDecoder | None = None, + device: Device | str = "cpu", + dtype: DType = torch.float32, + ) -> None: + image_encoder = image_encoder or SAMViTH() + point_encoder = point_encoder or PointEncoder() + mask_encoder = mask_encoder or MaskEncoder() + mask_decoder = mask_decoder or MaskDecoder() + + super().__init__( + image_encoder=image_encoder, + point_encoder=point_encoder, + mask_encoder=mask_encoder, + mask_decoder=mask_decoder, + device=device, + dtype=dtype, + ) diff --git a/src/refiners/foundationals/segment_anything/prompt_encoder.py b/src/refiners/foundationals/segment_anything/prompt_encoder.py new file mode 100644 index 0000000..c803d46 --- /dev/null +++ b/src/refiners/foundationals/segment_anything/prompt_encoder.py @@ -0,0 +1,190 @@ +from enum import Enum, auto +from collections.abc import Sequence +from torch import device as Device, dtype as DType, Tensor, nn +import torch +from jaxtyping import Float, Int +import refiners.fluxion.layers as fl +from refiners.fluxion.context import Contexts + + +class CoordinateEncoder(fl.Chain): + def __init__( + self, + num_positional_features: int = 64, + scale: float = 1, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + self.num_positional_features = num_positional_features + self.scale = scale + + super().__init__( + fl.Multiply(scale=2, bias=-1), + fl.Linear(in_features=2, out_features=num_positional_features, bias=False, device=device, dtype=dtype), + fl.Multiply(scale=2 * torch.pi * self.scale), + fl.Concatenate(fl.Sin(), fl.Cos(), dim=-1), + ) + + +class PointType(Enum): + BACKGROUND = auto() + FOREGROUND = auto() + BOX_TOP_LEFT = auto() + BOX_BOTTOM_RIGHT = auto() + NOT_A_POINT = auto() + + +class PointTypeEmbedding(fl.WeightedModule, fl.ContextModule): + def __init__(self, embedding_dim: int, device: Device | str | None = None, dtype: DType | None = None) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.weight = nn.Parameter(data=torch.randn(len(PointType), self.embedding_dim, device=device, dtype=dtype)) + + def forward(self, type_mask: Int[Tensor, "1 num_points"]) -> Float[Tensor, "1 num_points embedding_dim"]: + assert isinstance(type_mask, Tensor), "type_mask must be a Tensor." 
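+        # PointType values are 1-based (auto()), so the embedding for a given type lives at row `type_id.value - 1` of `weight`.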
+ + embeddings = torch.zeros(*type_mask.shape, self.embedding_dim).to(device=type_mask.device) + for type_id in PointType: + mask = type_mask == type_id.value + embeddings[mask] = self.weight[type_id.value - 1] + + return embeddings + + +class PointEncoder(fl.Chain): + def __init__( + self, embedding_dim: int = 256, scale: float = 1, device: Device | str | None = None, dtype: DType | None = None + ) -> None: + assert embedding_dim % 2 == 0, "embedding_dim must be divisible by 2." + self.embedding_dim = embedding_dim + self.scale = scale + + super().__init__( + CoordinateEncoder(num_positional_features=embedding_dim // 2, scale=scale, device=device, dtype=dtype), + fl.Lambda(func=self.pad), + fl.Residual( + fl.UseContext(context="point_encoder", key="type_mask"), + PointTypeEmbedding(embedding_dim=embedding_dim, device=device, dtype=dtype), + ), + ) + + def pad(self, x: Tensor) -> Tensor: + type_mask: Tensor = self.use_context("point_encoder")["type_mask"] + if torch.any((type_mask == PointType.BOX_TOP_LEFT.value) | (type_mask == PointType.BOX_BOTTOM_RIGHT.value)): + # Some boxes have been passed: no need to pad in this case + return x + type_mask = torch.cat( + [type_mask, torch.full((type_mask.shape[0], 1), PointType.NOT_A_POINT.value, device=type_mask.device)], + dim=1, + ) + self.set_context(context="point_encoder", value={"type_mask": type_mask}) + return torch.cat([x, torch.zeros((x.shape[0], 1, x.shape[-1]), device=x.device)], dim=1) + + def init_context(self) -> Contexts: + return { + "point_encoder": { + "type_mask": None, + } + } + + def set_type_mask(self, type_mask: Int[Tensor, "1 num_points"]) -> None: + self.set_context(context="point_encoder", value={"type_mask": type_mask}) + + def get_dense_positional_embedding( + self, image_embedding_size: tuple[int, int] + ) -> Float[Tensor, "num_positional_features height width"]: + coordinate_encoder = self.ensure_find(layer_type=CoordinateEncoder) + height, width = image_embedding_size + grid = torch.ones((height, width), device=self.device, dtype=torch.float32) + y_embedding = grid.cumsum(dim=0) - 0.5 + x_embedding = grid.cumsum(dim=1) - 0.5 + y_embedding = y_embedding / height + x_embedding = x_embedding / width + positional_embedding = ( + coordinate_encoder(torch.stack(tensors=[x_embedding, y_embedding], dim=-1)) + .permute(2, 0, 1) + .unsqueeze(dim=0) + ) + return positional_embedding + + def points_to_tensor( + self, + foreground_points: Sequence[tuple[float, float]] | None = None, + background_points: Sequence[tuple[float, float]] | None = None, + not_a_points: Sequence[tuple[float, float]] | None = None, + box_points: Sequence[Sequence[tuple[float, float]]] | None = None, + ) -> tuple[Float[Tensor, "1 num_points 2"], Int[Tensor, "1 num_points"]]: + foreground_points = foreground_points or [] + background_points = background_points or [] + not_a_points = not_a_points or [] + box_points = box_points or [] + top_left_points = [box[0] for box in box_points] + bottom_right_points = [box[1] for box in box_points] + coordinates: list[Tensor] = [] + type_ids: list[Tensor] = [] + + # Must be in sync with PointType enum + for type_id, coords_seq in zip( + PointType, [background_points, foreground_points, top_left_points, bottom_right_points, not_a_points] + ): + if len(coords_seq) > 0: + coords_tensor = torch.tensor(data=list(coords_seq), dtype=torch.float, device=self.device) + coordinates.append(coords_tensor) + point_ids = torch.tensor(data=[type_id.value] * len(coords_seq), dtype=torch.int, device=self.device) + 
type_ids.append(point_ids) + + all_coordinates = torch.cat(tensors=coordinates, dim=0).unsqueeze(dim=0) + type_mask = torch.cat(tensors=type_ids, dim=0).unsqueeze(dim=0) + + return all_coordinates, type_mask + + +class MaskEncoder(fl.Chain): + def __init__( + self, + embedding_dim: int = 256, + intermediate_channels: int = 16, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + self.embedding_dim = embedding_dim + self.intermediate_channels = intermediate_channels + super().__init__( + fl.Conv2d( + in_channels=1, + out_channels=self.intermediate_channels // 4, + kernel_size=2, + stride=2, + device=device, + dtype=dtype, + ), + fl.LayerNorm2d(channels=self.intermediate_channels // 4, device=device, dtype=dtype), + fl.GeLU(), + fl.Conv2d( + in_channels=self.intermediate_channels // 4, + out_channels=self.intermediate_channels, + kernel_size=2, + stride=2, + device=device, + dtype=dtype, + ), + fl.LayerNorm2d(channels=self.intermediate_channels, device=device, dtype=dtype), + fl.GeLU(), + fl.Conv2d( + in_channels=self.intermediate_channels, + out_channels=self.embedding_dim, + kernel_size=1, + device=device, + dtype=dtype, + ), + ) + self.register_parameter( + "no_mask_embedding", nn.Parameter(torch.randn(1, embedding_dim, device=device, dtype=dtype)) + ) + + def get_no_mask_dense_embedding( + self, image_embedding_size: tuple[int, int], batch_size: int = 1 + ) -> Float[Tensor, "batch embedding_dim image_embedding_height image_embedding_width"]: + return self.no_mask_embedding.reshape(1, -1, 1, 1).expand( + batch_size, -1, image_embedding_size[0], image_embedding_size[1] + ) diff --git a/src/refiners/foundationals/segment_anything/transformer.py b/src/refiners/foundationals/segment_anything/transformer.py new file mode 100644 index 0000000..4c72bab --- /dev/null +++ b/src/refiners/foundationals/segment_anything/transformer.py @@ -0,0 +1,157 @@ +from torch import dtype as DType, device as Device +import refiners.fluxion.layers as fl + + +class CrossAttention(fl.Attention): + def __init__( + self, + embedding_dim: int, + cross_embedding_dim: int | None = None, + num_heads: int = 1, + inner_dim: int | None = None, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + super().__init__( + embedding_dim=embedding_dim, + key_embedding_dim=cross_embedding_dim, + num_heads=num_heads, + inner_dim=inner_dim, + is_optimized=False, + device=device, + dtype=dtype, + ) + self.cross_embedding_dim = cross_embedding_dim or embedding_dim + self.insert(index=0, module=fl.Parallel(fl.GetArg(index=0), fl.GetArg(index=1), fl.GetArg(index=1))) + + +class FeedForward(fl.Residual): + def __init__( + self, embedding_dim: int, feed_forward_dim: int, device: Device | str | None = None, dtype: DType | None = None + ) -> None: + self.embedding_dim = embedding_dim + self.feed_forward_dim = feed_forward_dim + super().__init__( + fl.Linear(in_features=embedding_dim, out_features=feed_forward_dim, device=device, dtype=dtype), + fl.ReLU(), + fl.Linear(in_features=feed_forward_dim, out_features=embedding_dim, device=device, dtype=dtype), + ) + + +class SparseSelfAttention(fl.Residual): + def __init__( + self, + embedding_dim: int, + inner_dim: int | None = None, + num_heads: int = 1, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + add_sparse_embedding = fl.Residual(fl.UseContext(context="mask_decoder", key="sparse_embedding")) + super().__init__( + fl.Parallel(add_sparse_embedding, add_sparse_embedding, fl.Identity()), + fl.Attention( 
+ embedding_dim=embedding_dim, + inner_dim=inner_dim, + num_heads=num_heads, + is_optimized=False, + device=device, + dtype=dtype, + ), + ) + + +class SparseCrossDenseAttention(fl.Residual): + def __init__( + self, embedding_dim: int, num_heads: int = 8, device: Device | str | None = None, dtype: DType | None = None + ) -> None: + self.embedding_dim = embedding_dim + self.num_heads = num_heads + super().__init__( + fl.Parallel( + fl.Residual( + fl.UseContext(context="mask_decoder", key="sparse_embedding"), + ), + fl.Sum( + fl.UseContext(context="mask_decoder", key="dense_embedding"), + fl.UseContext(context="mask_decoder", key="dense_positional_embedding"), + ), + fl.UseContext(context="mask_decoder", key="dense_embedding"), + ), + fl.Attention( + embedding_dim=embedding_dim, + inner_dim=embedding_dim // 2, + num_heads=num_heads, + is_optimized=False, + device=device, + dtype=dtype, + ), + ) + + +class DenseCrossSparseAttention(fl.Chain): + def __init__( + self, embedding_dim: int, num_heads: int = 8, device: Device | str | None = None, dtype: DType | None = None + ) -> None: + super().__init__( + fl.Parallel( + fl.Sum( + fl.UseContext(context="mask_decoder", key="dense_embedding"), + fl.UseContext(context="mask_decoder", key="dense_positional_embedding"), + ), + fl.Residual( + fl.UseContext(context="mask_decoder", key="sparse_embedding"), + ), + fl.Identity(), + ), + fl.Attention( + embedding_dim=embedding_dim, + inner_dim=embedding_dim // 2, + num_heads=num_heads, + is_optimized=False, + device=device, + dtype=dtype, + ), + ) + + +class TwoWayTranformerLayer(fl.Chain): + def __init__( + self, + embedding_dim: int, + num_heads: int = 8, + feed_forward_dim: int = 2048, + use_residual_self_attention: bool = True, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.feed_forward_dim = feed_forward_dim + + self_attention = ( + SparseSelfAttention(embedding_dim=embedding_dim, num_heads=num_heads, device=device, dtype=dtype) + if use_residual_self_attention + else fl.SelfAttention( + embedding_dim=embedding_dim, num_heads=num_heads, is_optimized=False, device=device, dtype=dtype + ) + ) + + super().__init__( + self_attention, + fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype), + SparseCrossDenseAttention(embedding_dim=embedding_dim, num_heads=num_heads, device=device, dtype=dtype), + fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype), + FeedForward(embedding_dim=embedding_dim, feed_forward_dim=feed_forward_dim, device=device, dtype=dtype), + fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype), + fl.Passthrough( + fl.Sum( + fl.UseContext(context="mask_decoder", key="dense_embedding"), + DenseCrossSparseAttention( + embedding_dim=embedding_dim, num_heads=num_heads, device=device, dtype=dtype + ), + ), + fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype), + fl.SetContext(context="mask_decoder", key="dense_embedding"), + ), + ) diff --git a/tests/conftest.py b/tests/conftest.py index decf97f..b57459d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,3 +28,8 @@ def test_e2e_path() -> Path: @fixture(scope="session") def test_textual_inversion_path() -> Path: return PARENT_PATH / "foundationals" / "clip" / "test_concepts_ref" + + +@fixture(scope="session") +def test_sam_path() -> Path: + return PARENT_PATH / "foundationals" / "segment_anything" diff --git 
a/tests/foundationals/segment_anything/test_sam.py b/tests/foundationals/segment_anything/test_sam.py
new file mode 100644
index 0000000..b2b40bc
--- /dev/null
+++ b/tests/foundationals/segment_anything/test_sam.py
@@ -0,0 +1,317 @@
+from math import isclose
+from pathlib import Path
+from typing import cast
+from warnings import warn
+
+import pytest
+import torch
+import torch.nn as nn
+import numpy as np
+
+from PIL import Image
+from torch import Tensor
+from refiners.fluxion import manual_seed
+from refiners.fluxion.model_converter import ModelConverter
+
+from refiners.fluxion.utils import image_to_tensor
+from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention
+from refiners.foundationals.segment_anything.model import SegmentAnythingH
+from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer
+from tests.foundationals.segment_anything.utils import (
+    FacebookSAM,
+    FacebookSAMPredictor,
+    SAMPrompt,
+    intersection_over_union,
+)
+
+# See predictor_example.ipynb official notebook (note: mask_input is not yet properly supported)
+PROMPTS: list[SAMPrompt] = [
+    SAMPrompt(foreground_points=((500, 375),)),
+    SAMPrompt(background_points=((500, 375),)),
+    SAMPrompt(foreground_points=((500, 375), (1125, 625))),
+    SAMPrompt(foreground_points=((500, 375),), background_points=((1125, 625),)),
+    SAMPrompt(box_points=[[(425, 600), (700, 875)]]),
+    SAMPrompt(box_points=[[(425, 600), (700, 875)]], background_points=((575, 750),)),
+]
+
+
+@pytest.fixture(params=PROMPTS)
+def prompt(request: pytest.FixtureRequest) -> SAMPrompt:
+    return request.param
+
+
+@pytest.fixture
+def one_prompt() -> SAMPrompt:
+    return PROMPTS[0]
+
+
+@pytest.fixture(scope="module")
+def facebook_sam_h_weights(test_weights_path: Path) -> Path:
+    sam_h_weights = test_weights_path / "sam_vit_h_4b8939.pth"
+    if not sam_h_weights.is_file():
+        warn(f"could not find weights at {sam_h_weights}, skipping")
+        pytest.skip(allow_module_level=True)
+    return sam_h_weights
+
+
+@pytest.fixture(scope="module")
+def sam_h_weights(test_weights_path: Path) -> Path:
+    sam_h_weights = test_weights_path / "segment-anything-h.safetensors"
+    if not sam_h_weights.is_file():
+        warn(f"could not find weights at {sam_h_weights}, skipping")
+        pytest.skip(allow_module_level=True)
+    return sam_h_weights
+
+
+@pytest.fixture(scope="module")
+def facebook_sam_h(facebook_sam_h_weights: Path, test_device: torch.device) -> FacebookSAM:
+    from segment_anything import build_sam_vit_h  # type: ignore
+
+    sam_h = cast(FacebookSAM, build_sam_vit_h())
+    sam_h.load_state_dict(state_dict=torch.load(f=facebook_sam_h_weights))  # type: ignore
+    return sam_h.to(device=test_device)
+
+
+@pytest.fixture(scope="module")
+def facebook_sam_h_predictor(facebook_sam_h: FacebookSAM) -> FacebookSAMPredictor:
+    from segment_anything import SamPredictor  # type: ignore
+    from segment_anything.modeling import Sam  # type: ignore
+
+    predictor = SamPredictor(cast(Sam, facebook_sam_h))
+    return cast(FacebookSAMPredictor, predictor)
+
+
+@pytest.fixture(scope="module")
+def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH:
+    sam_h = SegmentAnythingH(device=test_device)
+    # TODO: make strict=True when the MaskEncoder conversion is done
+    sam_h.load_from_safetensors(tensors_path=sam_h_weights, strict=False)
+    return sam_h
+
+
+@pytest.fixture(scope="module")
+def ref_path(test_sam_path: Path) -> Path:
+    return test_sam_path / "test_sam_ref"
+
+
+@pytest.fixture
+def truck(ref_path: Path) -> Image.Image:
+    return Image.open(ref_path / "truck.jpg").convert("RGB")
+
+
+@torch.no_grad()
+def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None:
+    manual_seed(seed=0)
+    x = torch.randn(25, 14, 14, 1280, device=facebook_sam_h.device)
+
+    attention = cast(nn.Module, facebook_sam_h.image_encoder.blocks[0].attn)  # type: ignore
+
+    refiners_attention = FusedSelfAttention(
+        embedding_dim=1280, num_heads=16, spatial_size=(14, 14), device=facebook_sam_h.device
+    )
+    refiners_attention.Linear_1.weight = attention.qkv.weight  # type: ignore
+    refiners_attention.Linear_1.bias = attention.qkv.bias  # type: ignore
+    refiners_attention.Linear_2.weight = attention.proj.weight  # type: ignore
+    refiners_attention.Linear_2.bias = attention.proj.bias  # type: ignore
+    refiners_attention.RelativePositionAttention.horizontal_embedding = attention.rel_pos_w
+    refiners_attention.RelativePositionAttention.vertical_embedding = attention.rel_pos_h
+
+    y_1 = attention(x)
+    assert y_1.shape == x.shape
+
+    y_2 = refiners_attention(x)
+    assert y_2.shape == x.shape
+
+    assert torch.equal(input=y_1, other=y_2)
+
+
+@torch.no_grad()
+def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, truck: Image.Image) -> None:
+    image_tensor = image_to_tensor(image=truck.resize(size=(1024, 1024)), device=facebook_sam_h.device)
+    y_1 = facebook_sam_h.image_encoder(image_tensor)
+    y_2 = sam_h.image_encoder(image_tensor)
+
+    assert torch.equal(input=y_1, other=y_2)
+
+
+@torch.no_grad()
+def test_prompt_encoder_dense_positional_embedding(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
+    facebook_prompt_encoder = facebook_sam_h.prompt_encoder
+    refiners_prompt_encoder = sam_h.point_encoder
+
+    facebook_dense_pe: Tensor = cast(Tensor, facebook_prompt_encoder.get_dense_pe())  # type: ignore
+    refiners_dense_pe = refiners_prompt_encoder.get_dense_positional_embedding(image_embedding_size=(64, 64))
+
+    assert torch.equal(input=refiners_dense_pe, other=facebook_dense_pe)
+
+
+@torch.no_grad()
+def test_prompt_encoder_no_mask_dense_embedding(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
+    facebook_prompt_encoder = facebook_sam_h.prompt_encoder
+    refiners_prompt_encoder = sam_h.mask_encoder
+
+    _, facebook_dense_pe = facebook_prompt_encoder(points=None, boxes=None, masks=None)
+    refiners_dense_pe = refiners_prompt_encoder.get_no_mask_dense_embedding(image_embedding_size=(64, 64))
+
+    assert torch.equal(input=refiners_dense_pe, other=facebook_dense_pe)
+
+
+@torch.no_grad()
+def test_point_encoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH, prompt: SAMPrompt) -> None:
+    facebook_prompt_encoder = facebook_sam_h.prompt_encoder
+    refiners_prompt_encoder = sam_h.point_encoder
+
+    facebook_sparse_pe, _ = facebook_prompt_encoder(
+        **prompt.facebook_prompt_encoder_kwargs(device=facebook_sam_h.device)
+    )
+
+    coordinates, type_mask = refiners_prompt_encoder.points_to_tensor(**prompt.__dict__)
+    # Shift to center of pixel + normalize in [0, 1] (see `_embed_points` in segment-anything official repo)
+    coordinates[:, :, 0] = (coordinates[:, :, 0] + 0.5) / 1024.0
+    coordinates[:, :, 1] = (coordinates[:, :, 1] + 0.5) / 1024.0
+    refiners_prompt_encoder.set_type_mask(type_mask=type_mask)
+    refiners_sparse_pe = refiners_prompt_encoder(coordinates)
+
+    assert torch.equal(input=refiners_sparse_pe, other=facebook_sparse_pe)
+
+
+@torch.no_grad()
+def test_two_way_transformer(facebook_sam_h: FacebookSAM) -> None:
+    dense_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device)
+    dense_positional_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device)
+    sparse_embedding = torch.randn(1, 3, 256, device=facebook_sam_h.device)
+
+    refiners_layer = TwoWayTranformerLayer(
+        embedding_dim=256, feed_forward_dim=2048, num_heads=8, device=facebook_sam_h.device
+    )
+    facebook_layer = facebook_sam_h.mask_decoder.transformer.layers[1]  # type: ignore
+    assert isinstance(facebook_layer, nn.Module)
+
+    refiners_layer.set_context(
+        context="mask_decoder",
+        value={
+            "dense_embedding": dense_embedding,
+            "dense_positional_embedding": dense_positional_embedding,
+            "sparse_embedding": sparse_embedding,
+        },
+    )
+    facebook_inputs = {
+        "queries": sparse_embedding,
+        "keys": dense_embedding,
+        "query_pe": sparse_embedding,
+        "key_pe": dense_positional_embedding,
+    }
+
+    converter = ModelConverter(
+        source_model=facebook_layer,
+        target_model=refiners_layer,
+        skip_output_check=True,  # done below, manually
+    )
+
+    assert converter.run(source_args=facebook_inputs, target_args=(sparse_embedding,))
+
+    refiners_layer.set_context(
+        context="mask_decoder",
+        value={
+            "dense_embedding": dense_embedding,
+            "dense_positional_embedding": dense_positional_embedding,
+            "sparse_embedding": sparse_embedding,
+        },
+    )
+    y_1 = facebook_layer(**facebook_inputs)[0]
+    y_2 = refiners_layer(sparse_embedding)[0]
+
+    assert torch.equal(input=y_1, other=y_2)
+
+
+@torch.no_grad()
+def test_mask_decoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
+    manual_seed(seed=0)
+    facebook_mask_decoder = facebook_sam_h.mask_decoder
+    refiners_mask_decoder = sam_h.mask_decoder
+
+    image_embedding = torch.randn(1, 256, 64, 64, device=facebook_sam_h.device)
+    dense_positional_embedding = torch.randn(1, 256, 64, 64, device=facebook_sam_h.device)
+    point_embedding = torch.randn(1, 3, 256, device=facebook_sam_h.device)
+    mask_embedding = torch.randn(1, 256, 64, 64, device=facebook_sam_h.device)
+
+    import refiners.fluxion.layers as fl
+    from segment_anything.modeling.common import LayerNorm2d  # type: ignore
+
+    assert issubclass(LayerNorm2d, nn.Module)
+    custom_layers = {LayerNorm2d: fl.LayerNorm2d}
+
+    converter = ModelConverter(
+        source_model=facebook_mask_decoder,
+        target_model=refiners_mask_decoder,
+        custom_layer_mapping=custom_layers,  # type: ignore
+    )
+
+    inputs = {
+        "image_embeddings": image_embedding,
+        "image_pe": dense_positional_embedding,
+        "sparse_prompt_embeddings": point_embedding,
+        "dense_prompt_embeddings": mask_embedding,
+        "multimask_output": True,
+    }
+
+    refiners_mask_decoder.set_image_embedding(image_embedding)
+    refiners_mask_decoder.set_point_embedding(point_embedding)
+    refiners_mask_decoder.set_mask_embedding(mask_embedding)
+    refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding)
+
+    mapping = converter.map_state_dicts(source_args=inputs, target_args={})
+    assert mapping is not None
+    mapping["IOUMaskEncoder"] = "iou_token"
+
+    state_dict = converter._convert_state_dict(source_state_dict=facebook_mask_decoder.state_dict(), target_state_dict=refiners_mask_decoder.state_dict(), state_dict_mapping=mapping)  # type: ignore
+    state_dict["IOUMaskEncoder.weight"] = torch.cat([facebook_mask_decoder.iou_token.weight, facebook_mask_decoder.mask_tokens.weight], dim=0)  # type: ignore
+    refiners_mask_decoder.load_state_dict(state_dict=state_dict)
+
+    facebook_output = facebook_mask_decoder(**inputs)
+
+    refiners_mask_decoder.set_image_embedding(image_embedding)
+    refiners_mask_decoder.set_point_embedding(point_embedding)
+    refiners_mask_decoder.set_mask_embedding(mask_embedding)
+    refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding)
+    mask_prediction, iou_prediction = refiners_mask_decoder()
+
+    facebook_masks = facebook_output[0]
+    facebook_prediction = facebook_output[1]
+
+    assert torch.equal(input=mask_prediction, other=facebook_masks)
+    assert torch.equal(input=iou_prediction, other=facebook_prediction)
+
+
+@torch.no_grad()
+def test_predictor(
+    facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, prompt: SAMPrompt
+) -> None:
+    predictor = facebook_sam_h_predictor
+    predictor.set_image(np.array(truck))
+    facebook_masks, facebook_scores, _ = predictor.predict(**prompt.facebook_predict_kwargs())  # type: ignore
+
+    assert len(facebook_masks) == 3
+
+    masks, scores, _ = sam_h.predict(truck, **prompt.__dict__)
+    masks = masks.squeeze(0)
+    scores = scores.squeeze(0)
+
+    assert len(masks) == 3
+
+    for i in range(3):
+        mask_prediction = masks[i].cpu()
+        facebook_mask = torch.as_tensor(facebook_masks[i])
+        assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05)
+        assert isclose(scores[i].item(), facebook_scores[i].item(), rel_tol=1e-05)
+
+
+@torch.no_grad()
+def test_predictor_image_embedding(sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt) -> None:
+    masks_ref, scores_ref, _ = sam_h.predict(truck, **one_prompt.__dict__)
+
+    image_embedding = sam_h.compute_image_embedding(truck)
+    masks, scores, _ = sam_h.predict(image_embedding, **one_prompt.__dict__)
+
+    assert torch.equal(masks, masks_ref)
+    assert torch.equal(scores_ref, scores)
diff --git a/tests/foundationals/segment_anything/test_sam_ref/README.md b/tests/foundationals/segment_anything/test_sam_ref/README.md
new file mode 100644
index 0000000..0ad3a32
--- /dev/null
+++ b/tests/foundationals/segment_anything/test_sam_ref/README.md
@@ -0,0 +1,3 @@
+# Note about this data
+
+`truck.jpg` is one of the [images](https://github.com/facebookresearch/segment-anything/tree/main/notebooks/images) used in the official [segment-anything notebooks](https://github.com/facebookresearch/segment-anything/tree/main/notebooks).
diff --git a/tests/foundationals/segment_anything/test_sam_ref/truck.jpg b/tests/foundationals/segment_anything/test_sam_ref/truck.jpg
new file mode 100644
index 0000000..6b98688
Binary files /dev/null and b/tests/foundationals/segment_anything/test_sam_ref/truck.jpg differ
diff --git a/tests/foundationals/segment_anything/utils.py b/tests/foundationals/segment_anything/utils.py
new file mode 100644
index 0000000..37085e2
--- /dev/null
+++ b/tests/foundationals/segment_anything/utils.py
@@ -0,0 +1,107 @@
+from collections.abc import Sequence
+from dataclasses import dataclass
+from typing import Any, TypedDict
+
+from jaxtyping import Bool
+from torch import Tensor, nn
+import numpy as np
+import numpy.typing as npt
+import torch
+
+NDArrayUInt8 = npt.NDArray[np.uint8]
+NDArray = npt.NDArray[Any]
+
+
+class SAMInput(TypedDict):
+    image: Tensor
+    original_size: tuple[int, int]
+    point_coords: Tensor | None
+    point_labels: Tensor | None
+    boxes: Tensor | None
+    mask_inputs: Tensor | None
+
+
+class SAMOutput(TypedDict):
+    masks: Tensor
+    iou_predictions: Tensor
+    low_res_logits: Tensor
+
+
+class FacebookSAM(nn.Module):
+    image_encoder: nn.Module
+    prompt_encoder: nn.Module
+    mask_decoder: nn.Module
+
+    def __call__(self, batched_input: list[SAMInput], multimask_output: bool) -> list[SAMOutput]: ...
+
+    @property
+    def device(self) -> Any: ...
+
+
+class FacebookSAMPredictor:
+    model: FacebookSAM
+
+    def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ...
+
+    def predict(
+        self,
+        point_coords: NDArray | None = None,
+        point_labels: NDArray | None = None,
+        box: NDArray | None = None,
+        mask_input: NDArray | None = None,
+        multimask_output: bool = True,
+        return_logits: bool = False,
+    ) -> tuple[NDArray, NDArray, NDArray]: ...
+
+
+@dataclass
+class SAMPrompt:
+    foreground_points: Sequence[tuple[float, float]] | None = None
+    background_points: Sequence[tuple[float, float]] | None = None
+    box_points: Sequence[Sequence[tuple[float, float]]] | None = None
+    # TODO: support masks
+    # masks: Sequence[Image.Image] | None = None
+
+    def facebook_predict_kwargs(self) -> dict[str, NDArray]:
+        prompt: dict[str, NDArray] = {}
+        # Note: the order matters since `points_to_tensor` processes points that way (background -> foreground -> etc)
+        if self.background_points:
+            prompt["point_coords"] = np.array(self.background_points)
+            prompt["point_labels"] = np.array([0] * len(self.background_points))
+        if self.foreground_points:
+            coords = np.array(self.foreground_points)
+            prompt["point_coords"] = (
+                coords if "point_coords" not in prompt else np.concatenate((prompt["point_coords"], coords))
+            )
+            labels = np.array([1] * len(self.foreground_points))
+            prompt["point_labels"] = (
+                labels if "point_labels" not in prompt else np.concatenate((prompt["point_labels"], labels))
+            )
+        if self.box_points:
+            prompt["box"] = np.array([coord for batch in self.box_points for xy in batch for coord in xy]).reshape(
+                len(self.box_points), 4
+            )
+        return prompt
+
+    def facebook_prompt_encoder_kwargs(self, device: torch.device | None = None) -> dict[str, Any]:
+        prompt = self.facebook_predict_kwargs()
+        coords: Tensor | None = None
+        labels: Tensor | None = None
+        boxes: Tensor | None = None
+        if "point_coords" in prompt:
+            coords = torch.as_tensor(prompt["point_coords"], dtype=torch.float, device=device).unsqueeze(0)
+        if "point_labels" in prompt:
+            labels = torch.as_tensor(prompt["point_labels"], dtype=torch.int, device=device).unsqueeze(0)
+        if "box" in prompt:
+            boxes = torch.as_tensor(prompt["box"], dtype=torch.float, device=device).unsqueeze(0)
+        points = (coords, labels) if coords is not None else None
+        # TODO: support masks
+        return {"points": points, "boxes": boxes, "masks": None}
+
+
+def intersection_over_union(
+    input_mask: Bool[Tensor, "height width"], other_mask: Bool[Tensor, "height width"]
+) -> float:
+    inter = (input_mask & other_mask).sum(dtype=torch.float32).item()
+    union = (input_mask | other_mask).sum(dtype=torch.float32).item()
+    return inter / union if union > 0 else 1.0
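For context, here is a minimal standalone sketch of how the `intersection_over_union` helper above is meant to be used when comparing a refiners mask against a Facebook reference mask, as done in `test_predictor`. The two masks below are synthetic placeholders (not real model outputs), and the snippet assumes it is run from the repository root so that the `tests` package is importable.

import torch

from tests.foundationals.segment_anything.utils import intersection_over_union

# Two nearly identical boolean masks standing in for `masks[i].cpu()` and
# `torch.as_tensor(facebook_masks[i])` in `test_predictor`.
refiners_mask = torch.zeros(1200, 1800, dtype=torch.bool)
facebook_mask = torch.zeros(1200, 1800, dtype=torch.bool)
refiners_mask[300:700, 400:900] = True
facebook_mask[300:700, 400:905] = True  # right edge differs by a few pixels

iou = intersection_over_union(refiners_mask, facebook_mask)
print(f"IoU = {iou:.5f}")  # ~0.99 here; the test expects isclose(iou, 1.0, rel_tol=5e-05)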