From 9b2b109897127d282b00c7d568407c228b5349da Mon Sep 17 00:00:00 2001 From: Laurent Date: Fri, 2 Feb 2024 13:51:43 +0000 Subject: [PATCH] (doc/foundationals) add `SegmentAnything`, related docstrings --- .../segment_anything/__init__.py | 3 + .../foundationals/segment_anything/model.py | 92 +++++++++++++++++++ 2 files changed, 95 insertions(+) diff --git a/src/refiners/foundationals/segment_anything/__init__.py b/src/refiners/foundationals/segment_anything/__init__.py index e69de29..5f1d54d 100644 --- a/src/refiners/foundationals/segment_anything/__init__.py +++ b/src/refiners/foundationals/segment_anything/__init__.py @@ -0,0 +1,3 @@ +from refiners.foundationals.segment_anything.model import SegmentAnything, SegmentAnythingH + +__all__ = ["SegmentAnything", "SegmentAnythingH"] diff --git a/src/refiners/foundationals/segment_anything/model.py b/src/refiners/foundationals/segment_anything/model.py index 905c4b6..8cc6142 100644 --- a/src/refiners/foundationals/segment_anything/model.py +++ b/src/refiners/foundationals/segment_anything/model.py @@ -21,6 +21,14 @@ class ImageEmbedding: class SegmentAnything(fl.Module): + """SegmentAnything model. + + See [[arXiv:2304.02643] Segment Anything](https://arxiv.org/abs/2304.02643) + + Attributes: + mask_threshold (float): 0.0 + """ + mask_threshold: float = 0.0 def __init__( @@ -32,6 +40,16 @@ class SegmentAnything(fl.Module): device: Device | str = "cpu", dtype: DType = torch.float32, ) -> None: + """Initialize SegmentAnything model. + + Args: + image_encoder: The image encoder to use. + point_encoder: The point encoder to use. + mask_encoder: The mask encoder to use. + mask_decoder: The mask decoder to use. + device: The PyTorch device to use. + dtype: The PyTorch data type to use. + """ super().__init__() self.device: Device = device if isinstance(device, Device) else Device(device=device) self.dtype = dtype @@ -42,6 +60,14 @@ class SegmentAnything(fl.Module): @no_grad() def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding: + """Compute the emmbedding of an image. + + Args: + image: The image to compute the embedding of. + + Returns: + The computed image embedding. + """ original_size = (image.height, image.width) target_size = self.compute_target_size(original_size) return ImageEmbedding( @@ -59,6 +85,21 @@ class SegmentAnything(fl.Module): low_res_mask: Float[Tensor, "1 1 256 256"] | None = None, binarize: bool = True, ) -> tuple[Tensor, Tensor, Tensor]: + """Predict the masks of the input image. + + Args: + input: The input image or its embedding. + foreground_points: The points of the foreground. + background_points: The points of the background. + box_points: The points of the box. + low_res_mask: The low resolution mask. + binarize: Whether to binarize the masks. + + Returns: + The predicted masks. + The IOU prediction. + The low resolution masks. + """ if isinstance(input, ImageEmbedding): original_size = input.original_image_size target_size = self.compute_target_size(original_size) @@ -107,11 +148,21 @@ class SegmentAnything(fl.Module): @property def image_size(self) -> int: + """The image size.""" w, h = self.image_encoder.image_size assert w == h return w def compute_target_size(self, size: tuple[int, int]) -> tuple[int, int]: + """Compute the target size for a given size. + + Args: + size: The size of the image. + + Returns: + The target height. + The target width. + """ oldh, oldw = size scale = self.image_size * 1.0 / max(oldh, oldw) newh, neww = oldh * scale, oldw * scale @@ -120,6 +171,15 @@ class SegmentAnything(fl.Module): return (newh, neww) def preprocess_image(self, image: Image.Image, target_size: tuple[int, int]) -> Tensor: + """Preprocess an image. + + Args: + image: The image to preprocess. + target_size: The target size. + + Returns: + The preprocessed image. + """ h, w = target_size padh = self.image_size - h padw = self.image_size - w @@ -133,11 +193,31 @@ class SegmentAnything(fl.Module): ) def normalize(self, coordinates: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor: + """Normalize the coordinates. + + Args: + coordinates: The coordinates to normalize. + target_size: The target size. + original_size: The original size. + + Returns: + The normalized coordinates. + """ coordinates[:, :, 0] = ((coordinates[:, :, 0] * (target_size[1] / original_size[1])) + 0.5) / self.image_size coordinates[:, :, 1] = ((coordinates[:, :, 1] * (target_size[0] / original_size[0])) + 0.5) / self.image_size return coordinates def postprocess_masks(self, masks: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor: + """Postprocess the masks. + + Args: + masks: The masks to postprocess. + target_size: The target size. + original_size: The original size. + + Returns: + The postprocessed masks. + """ masks = interpolate(masks, factor=torch.Size((self.image_size, self.image_size)), mode="bilinear") masks = masks[..., : target_size[0], : target_size[1]] # remove padding added at `preprocess_image` time masks = interpolate(masks, factor=torch.Size(original_size), mode="bilinear") @@ -145,6 +225,8 @@ class SegmentAnything(fl.Module): class SegmentAnythingH(SegmentAnything): + """SegmentAnything huge model.""" + def __init__( self, image_encoder: SAMViTH | None = None, @@ -154,6 +236,16 @@ class SegmentAnythingH(SegmentAnything): device: Device | str = "cpu", dtype: DType = torch.float32, ) -> None: + """Initialize SegmentAnything huge model. + + Args: + image_encoder: The image encoder to use. + point_encoder: The point encoder to use. + mask_encoder: The mask encoder to use. + mask_decoder: The mask decoder to use. + device: The PyTorch device to use. + dtype: The PyTorch data type to use. + """ image_encoder = image_encoder or SAMViTH() point_encoder = point_encoder or PointEncoder() mask_encoder = mask_encoder or MaskEncoder()