(doc/foundationals) add SegmentAnything, related docstrings

Laurent authored on 2024-02-02 13:51:43 +00:00; committed by Laureηt
parent 7bc5ce35d2
commit 9b2b109897
2 changed files with 95 additions and 0 deletions

File 1 of 2: refiners/foundationals/segment_anything/__init__.py (new file)
@@ -0,0 +1,3 @@
from refiners.foundationals.segment_anything.model import SegmentAnything, SegmentAnythingH
__all__ = ["SegmentAnything", "SegmentAnythingH"]

File 2 of 2: refiners/foundationals/segment_anything/model.py
@@ -21,6 +21,14 @@ class ImageEmbedding:

class SegmentAnything(fl.Module):
    """SegmentAnything model.

    See [[arXiv:2304.02643] Segment Anything](https://arxiv.org/abs/2304.02643)

    Attributes:
        mask_threshold (float): The threshold above which mask logits are binarized. Defaults to 0.0.
    """

    mask_threshold: float = 0.0

    def __init__(
@@ -32,6 +40,16 @@ class SegmentAnything(fl.Module):
        device: Device | str = "cpu",
        dtype: DType = torch.float32,
    ) -> None:
        """Initialize SegmentAnything model.

        Args:
            image_encoder: The image encoder to use.
            point_encoder: The point encoder to use.
            mask_encoder: The mask encoder to use.
            mask_decoder: The mask decoder to use.
            device: The PyTorch device to use.
            dtype: The PyTorch data type to use.
        """
        super().__init__()
        self.device: Device = device if isinstance(device, Device) else Device(device=device)
        self.dtype = dtype
@@ -42,6 +60,14 @@ class SegmentAnything(fl.Module):
    @no_grad()
    def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
        """Compute the embedding of an image.

        Args:
            image: The image to compute the embedding of.

        Returns:
            The computed image embedding.
        """
        original_size = (image.height, image.width)
        target_size = self.compute_target_size(original_size)
        return ImageEmbedding(
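A usage sketch for this method: since it runs under `no_grad` and the embedding does not depend on the prompts, it can be computed once and reused across several `predict` calls (the image path is hypothetical):

from PIL import Image

image = Image.open("truck.jpg")  # hypothetical example image
embedding = sam.compute_image_embedding(image)  # one encoder pass, no gradients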
@@ -59,6 +85,21 @@ class SegmentAnything(fl.Module):
        low_res_mask: Float[Tensor, "1 1 256 256"] | None = None,
        binarize: bool = True,
    ) -> tuple[Tensor, Tensor, Tensor]:
        """Predict the masks of the input image.

        Args:
            input: The input image or its embedding.
            foreground_points: The points of the foreground.
            background_points: The points of the background.
            box_points: The points of the box.
            low_res_mask: The low resolution mask.
            binarize: Whether to binarize the masks.

        Returns:
            The predicted masks.
            The IoU prediction.
            The low resolution masks.
        """
        if isinstance(input, ImageEmbedding):
            original_size = input.original_image_size
            target_size = self.compute_target_size(original_size)
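A sketch of calling `predict` with the cached embedding from above, assuming prompt points are given as (x, y) pixel coordinates in the original image:

masks, iou_prediction, low_res_masks = sam.predict(
    embedding,  # an ImageEmbedding, or a PIL image to re-encode on the fly
    foreground_points=[(320.0, 240.0)],  # hypothetical point on the object
)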
@@ -107,11 +148,21 @@ class SegmentAnything(fl.Module):
    @property
    def image_size(self) -> int:
        """The image size."""
        w, h = self.image_encoder.image_size
        assert w == h
        return w

    def compute_target_size(self, size: tuple[int, int]) -> tuple[int, int]:
        """Compute the target size for a given size.

        Args:
            size: The size of the image.

        Returns:
            The target height.
            The target width.
        """
        oldh, oldw = size
        scale = self.image_size * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
@@ -120,6 +171,15 @@ class SegmentAnything(fl.Module):
        return (newh, neww)
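A worked example of the scaling above, assuming the default input size of 1024: the longest side maps exactly onto `image_size`, and the other side is scaled by the same factor (presumably rounded in the lines elided from this hunk):

# For a 1500x2250 (h, w) image with image_size == 1024:
#   scale = 1024 / 2250 ≈ 0.4551
#   newh, neww ≈ 682.7, 1024.0  ->  (683, 1024) after rounding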
    def preprocess_image(self, image: Image.Image, target_size: tuple[int, int]) -> Tensor:
        """Preprocess an image.

        Args:
            image: The image to preprocess.
            target_size: The target size.

        Returns:
            The preprocessed image.
        """
        h, w = target_size
        padh = self.image_size - h
        padw = self.image_size - w
@@ -133,11 +193,31 @@ class SegmentAnything(fl.Module):
        )

    def normalize(self, coordinates: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
        """Normalize the coordinates.

        Args:
            coordinates: The coordinates to normalize.
            target_size: The target size.
            original_size: The original size.

        Returns:
            The normalized coordinates.
        """
        coordinates[:, :, 0] = ((coordinates[:, :, 0] * (target_size[1] / original_size[1])) + 0.5) / self.image_size
        coordinates[:, :, 1] = ((coordinates[:, :, 1] * (target_size[0] / original_size[0])) + 0.5) / self.image_size
        return coordinates
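In other words, a prompt coordinate is rescaled from the original image to the resized (pre-padding) image, shifted by half a pixel to its center, then divided by `image_size` to land in [0, 1] relative to the padded square. Continuing the 1500x2250 example with target size (683, 1024):

# x: (2250 * (1024 / 2250) + 0.5) / 1024 ≈ 1.0005  (right edge of the content)
# y: (1500 * (683 / 1500) + 0.5) / 1024 ≈ 0.6675   (bottom of the un-padded area)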
    def postprocess_masks(self, masks: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
        """Postprocess the masks.

        Args:
            masks: The masks to postprocess.
            target_size: The target size.
            original_size: The original size.

        Returns:
            The postprocessed masks.
        """
        masks = interpolate(masks, factor=torch.Size((self.image_size, self.image_size)), mode="bilinear")
        masks = masks[..., : target_size[0], : target_size[1]]  # remove padding added at `preprocess_image` time
        masks = interpolate(masks, factor=torch.Size(original_size), mode="bilinear")
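Tracing the mask shapes through these three steps, still with the 1500x2250 example (N is the number of predicted masks):

# (1, N, 256, 256)        low-resolution decoder output
#   -> (1, N, 1024, 1024)  upscaled to the padded model resolution
#   -> (1, N, 683, 1024)   padding cropped away
#   -> (1, N, 1500, 2250)  resized back to the original image size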
@@ -145,6 +225,8 @@ class SegmentAnything(fl.Module):

class SegmentAnythingH(SegmentAnything):
    """SegmentAnything huge model."""

    def __init__(
        self,
        image_encoder: SAMViTH | None = None,
@@ -154,6 +236,16 @@ class SegmentAnythingH(SegmentAnything):
        device: Device | str = "cpu",
        dtype: DType = torch.float32,
    ) -> None:
        """Initialize SegmentAnything huge model.

        Args:
            image_encoder: The image encoder to use.
            point_encoder: The point encoder to use.
            mask_encoder: The mask encoder to use.
            mask_decoder: The mask decoder to use.
            device: The PyTorch device to use.
            dtype: The PyTorch data type to use.
        """
        image_encoder = image_encoder or SAMViTH()
        point_encoder = point_encoder or PointEncoder()
        mask_encoder = mask_encoder or MaskEncoder()
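Putting the pieces together, a minimal end-to-end sketch; the checkpoint and image paths are hypothetical, and it assumes `load_from_safetensors` for loading converted weights, as elsewhere in refiners:

import torch
from PIL import Image
from refiners.foundationals.segment_anything import SegmentAnythingH

sam = SegmentAnythingH(device="cuda", dtype=torch.float32)
sam.load_from_safetensors("segment-anything-h.safetensors")  # hypothetical path

image = Image.open("truck.jpg")  # hypothetical image
masks, iou_prediction, low_res_masks = sam.predict(
    image,
    foreground_points=[(320.0, 240.0)],  # hypothetical (x, y) prompt
)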