refiners/tests/foundationals/segment_anything/utils.py
2024-03-23 13:04:00 +00:00

133 lines
4.6 KiB
Python

from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, TypedDict
import numpy as np
import numpy.typing as npt
import torch
from jaxtyping import Bool
from torch import Tensor, nn
NDArrayUInt8 = npt.NDArray[np.uint8]
NDArray = npt.NDArray[Any]
class SAMInput(TypedDict):
image: Tensor
original_size: tuple[int, int]
point_coords: Tensor | None
point_labels: Tensor | None
boxes: Tensor | None
mask_inputs: Tensor | None
class SAMOutput(TypedDict):
masks: Tensor
iou_predictions: Tensor
low_res_logits: Tensor
class FacebookSAM(nn.Module):
image_encoder: nn.Module
prompt_encoder: nn.Module
mask_decoder: nn.Module
def __call__(self, batched_input: list[SAMInput], multimask_output: bool) -> list[SAMOutput]: ...
@property
def device(self) -> Any: ...
class FacebookSAMPredictor:
model: FacebookSAM
features: Tensor
def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ...
def predict(
self,
point_coords: NDArray | None = None,
point_labels: NDArray | None = None,
box: NDArray | None = None,
mask_input: NDArray | None = None,
multimask_output: bool = True,
return_logits: bool = False,
) -> tuple[NDArray, NDArray, NDArray]: ...
class FacebookSAMPredictorHQ:
model: FacebookSAM
features: Tensor
interm_features: Tensor
def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ...
def predict(
self,
point_coords: NDArray | None = None,
point_labels: NDArray | None = None,
box: NDArray | None = None,
mask_input: NDArray | None = None,
multimask_output: bool = True,
return_logits: bool = False,
hq_token_only: bool = False,
) -> tuple[NDArray, NDArray, NDArray]: ...
@dataclass
class SAMPrompt:
foreground_points: Sequence[tuple[float, float]] | None = None
background_points: Sequence[tuple[float, float]] | None = None
box_points: Sequence[Sequence[tuple[float, float]]] | None = None
low_res_mask: Tensor | None = None
def facebook_predict_kwargs(self) -> dict[str, NDArray]:
prompt: dict[str, NDArray] = {}
# Note: the order matters since `points_to_tensor` processes points that way (background -> foreground -> etc)
if self.background_points:
prompt["point_coords"] = np.array(self.background_points)
prompt["point_labels"] = np.array([0] * len(self.background_points))
if self.foreground_points:
coords = np.array(self.foreground_points)
prompt["point_coords"] = (
coords if "point_coords" not in prompt else np.concatenate((prompt["point_coords"], coords))
)
labels = np.array([1] * len(self.foreground_points))
prompt["point_labels"] = (
labels if "point_labels" not in prompt else np.concatenate((prompt["point_labels"], labels))
)
if self.box_points:
prompt["box"] = np.array([coord for batch in self.box_points for xy in batch for coord in xy]).reshape(
len(self.box_points), 4
)
if self.low_res_mask is not None:
prompt["mask_input"] = np.array(self.low_res_mask)
return prompt
def facebook_prompt_encoder_kwargs(
self, device: torch.device | None = None
) -> dict[str, Tensor | tuple[Tensor, Tensor | None] | None]:
prompt = self.facebook_predict_kwargs()
coords: Tensor | None = None
labels: Tensor | None = None
boxes: Tensor | None = None
masks: Tensor | None = None
if "point_coords" in prompt:
coords = torch.as_tensor(prompt["point_coords"], dtype=torch.float, device=device).unsqueeze(0)
if "point_labels" in prompt:
labels = torch.as_tensor(prompt["point_labels"], dtype=torch.int, device=device).unsqueeze(0)
if "box" in prompt:
boxes = torch.as_tensor(prompt["box"], dtype=torch.float, device=device).unsqueeze(0)
points = (coords, labels) if coords is not None else None
if "mask_input" in prompt:
masks = torch.as_tensor(prompt["mask_input"], dtype=torch.float, device=device).unsqueeze(0)
return {"points": points, "boxes": boxes, "masks": masks}
def intersection_over_union(
input_mask: Bool[Tensor, "height width"], other_mask: Bool[Tensor, "height width"]
) -> float:
inter = (input_mask & other_mask).sum(dtype=torch.float32).item()
union = (input_mask | other_mask).sum(dtype=torch.float32).item()
return inter / union if union > 0 else 1.0