mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 07:08:45 +00:00
(doc/foundationals) add SegmentAnything
, related docstrings
This commit is contained in:
parent
a926696141
commit
f62e71da1c
|
@ -0,0 +1,3 @@
|
|||
from refiners.foundationals.segment_anything.model import SegmentAnything, SegmentAnythingH
|
||||
|
||||
__all__ = ["SegmentAnything", "SegmentAnythingH"]
|
|
@ -21,6 +21,14 @@ class ImageEmbedding:
|
|||
|
||||
|
||||
class SegmentAnything(fl.Module):
|
||||
"""SegmentAnything model.
|
||||
|
||||
See [[arXiv:2304.02643] Segment Anything](https://arxiv.org/abs/2304.02643)
|
||||
|
||||
Attributes:
|
||||
mask_threshold (float): 0.0
|
||||
"""
|
||||
|
||||
mask_threshold: float = 0.0
|
||||
|
||||
def __init__(
|
||||
|
@ -32,6 +40,16 @@ class SegmentAnything(fl.Module):
|
|||
device: Device | str = "cpu",
|
||||
dtype: DType = torch.float32,
|
||||
) -> None:
|
||||
"""Initialize SegmentAnything model.
|
||||
|
||||
Args:
|
||||
image_encoder: The image encoder to use.
|
||||
point_encoder: The point encoder to use.
|
||||
mask_encoder: The mask encoder to use.
|
||||
mask_decoder: The mask decoder to use.
|
||||
device: The PyTorch device to use.
|
||||
dtype: The PyTorch data type to use.
|
||||
"""
|
||||
super().__init__()
|
||||
self.device: Device = device if isinstance(device, Device) else Device(device=device)
|
||||
self.dtype = dtype
|
||||
|
@ -42,6 +60,14 @@ class SegmentAnything(fl.Module):
|
|||
|
||||
@no_grad()
|
||||
def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
|
||||
"""Compute the emmbedding of an image.
|
||||
|
||||
Args:
|
||||
image: The image to compute the embedding of.
|
||||
|
||||
Returns:
|
||||
The computed image embedding.
|
||||
"""
|
||||
original_size = (image.height, image.width)
|
||||
target_size = self.compute_target_size(original_size)
|
||||
return ImageEmbedding(
|
||||
|
@ -59,6 +85,21 @@ class SegmentAnything(fl.Module):
|
|||
low_res_mask: Float[Tensor, "1 1 256 256"] | None = None,
|
||||
binarize: bool = True,
|
||||
) -> tuple[Tensor, Tensor, Tensor]:
|
||||
"""Predict the masks of the input image.
|
||||
|
||||
Args:
|
||||
input: The input image or its embedding.
|
||||
foreground_points: The points of the foreground.
|
||||
background_points: The points of the background.
|
||||
box_points: The points of the box.
|
||||
low_res_mask: The low resolution mask.
|
||||
binarize: Whether to binarize the masks.
|
||||
|
||||
Returns:
|
||||
The predicted masks.
|
||||
The IOU prediction.
|
||||
The low resolution masks.
|
||||
"""
|
||||
if isinstance(input, ImageEmbedding):
|
||||
original_size = input.original_image_size
|
||||
target_size = self.compute_target_size(original_size)
|
||||
|
@ -107,11 +148,21 @@ class SegmentAnything(fl.Module):
|
|||
|
||||
@property
|
||||
def image_size(self) -> int:
|
||||
"""The image size."""
|
||||
w, h = self.image_encoder.image_size
|
||||
assert w == h
|
||||
return w
|
||||
|
||||
def compute_target_size(self, size: tuple[int, int]) -> tuple[int, int]:
|
||||
"""Compute the target size for a given size.
|
||||
|
||||
Args:
|
||||
size: The size of the image.
|
||||
|
||||
Returns:
|
||||
The target height.
|
||||
The target width.
|
||||
"""
|
||||
oldh, oldw = size
|
||||
scale = self.image_size * 1.0 / max(oldh, oldw)
|
||||
newh, neww = oldh * scale, oldw * scale
|
||||
|
@ -120,6 +171,15 @@ class SegmentAnything(fl.Module):
|
|||
return (newh, neww)
|
||||
|
||||
def preprocess_image(self, image: Image.Image, target_size: tuple[int, int]) -> Tensor:
|
||||
"""Preprocess an image.
|
||||
|
||||
Args:
|
||||
image: The image to preprocess.
|
||||
target_size: The target size.
|
||||
|
||||
Returns:
|
||||
The preprocessed image.
|
||||
"""
|
||||
h, w = target_size
|
||||
padh = self.image_size - h
|
||||
padw = self.image_size - w
|
||||
|
@ -133,11 +193,31 @@ class SegmentAnything(fl.Module):
|
|||
)
|
||||
|
||||
def normalize(self, coordinates: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
|
||||
"""Normalize the coordinates.
|
||||
|
||||
Args:
|
||||
coordinates: The coordinates to normalize.
|
||||
target_size: The target size.
|
||||
original_size: The original size.
|
||||
|
||||
Returns:
|
||||
The normalized coordinates.
|
||||
"""
|
||||
coordinates[:, :, 0] = ((coordinates[:, :, 0] * (target_size[1] / original_size[1])) + 0.5) / self.image_size
|
||||
coordinates[:, :, 1] = ((coordinates[:, :, 1] * (target_size[0] / original_size[0])) + 0.5) / self.image_size
|
||||
return coordinates
|
||||
|
||||
def postprocess_masks(self, masks: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
|
||||
"""Postprocess the masks.
|
||||
|
||||
Args:
|
||||
masks: The masks to postprocess.
|
||||
target_size: The target size.
|
||||
original_size: The original size.
|
||||
|
||||
Returns:
|
||||
The postprocessed masks.
|
||||
"""
|
||||
masks = interpolate(masks, factor=torch.Size((self.image_size, self.image_size)), mode="bilinear")
|
||||
masks = masks[..., : target_size[0], : target_size[1]] # remove padding added at `preprocess_image` time
|
||||
masks = interpolate(masks, factor=torch.Size(original_size), mode="bilinear")
|
||||
|
@ -145,6 +225,8 @@ class SegmentAnything(fl.Module):
|
|||
|
||||
|
||||
class SegmentAnythingH(SegmentAnything):
|
||||
"""SegmentAnything huge model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_encoder: SAMViTH | None = None,
|
||||
|
@ -154,6 +236,16 @@ class SegmentAnythingH(SegmentAnything):
|
|||
device: Device | str = "cpu",
|
||||
dtype: DType = torch.float32,
|
||||
) -> None:
|
||||
"""Initialize SegmentAnything huge model.
|
||||
|
||||
Args:
|
||||
image_encoder: The image encoder to use.
|
||||
point_encoder: The point encoder to use.
|
||||
mask_encoder: The mask encoder to use.
|
||||
mask_decoder: The mask decoder to use.
|
||||
device: The PyTorch device to use.
|
||||
dtype: The PyTorch data type to use.
|
||||
"""
|
||||
image_encoder = image_encoder or SAMViTH()
|
||||
point_encoder = point_encoder or PointEncoder()
|
||||
mask_encoder = mask_encoder or MaskEncoder()
|
||||
|
|
Loading…
Reference in a new issue