mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 14:48:45 +00:00
(doc/foundationals) add SegmentAnything, related docstrings
This commit is contained in:
parent
7bc5ce35d2
commit
9b2b109897
|
@ -0,0 +1,3 @@
|
||||||
|
from refiners.foundationals.segment_anything.model import SegmentAnything, SegmentAnythingH
|
||||||
|
|
||||||
|
__all__ = ["SegmentAnything", "SegmentAnythingH"]
|
|
@ -21,6 +21,14 @@ class ImageEmbedding:
|
||||||
|
|
||||||
|
|
||||||
class SegmentAnything(fl.Module):
|
class SegmentAnything(fl.Module):
|
||||||
|
"""SegmentAnything model.
|
||||||
|
|
||||||
|
See [[arXiv:2304.02643] Segment Anything](https://arxiv.org/abs/2304.02643)
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
mask_threshold (float): 0.0
|
||||||
|
"""
|
||||||
|
|
||||||
mask_threshold: float = 0.0
|
mask_threshold: float = 0.0
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -32,6 +40,16 @@ class SegmentAnything(fl.Module):
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: DType = torch.float32,
|
dtype: DType = torch.float32,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Initialize SegmentAnything model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_encoder: The image encoder to use.
|
||||||
|
point_encoder: The point encoder to use.
|
||||||
|
mask_encoder: The mask encoder to use.
|
||||||
|
mask_decoder: The mask decoder to use.
|
||||||
|
device: The PyTorch device to use.
|
||||||
|
dtype: The PyTorch data type to use.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device: Device = device if isinstance(device, Device) else Device(device=device)
|
self.device: Device = device if isinstance(device, Device) else Device(device=device)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
@ -42,6 +60,14 @@ class SegmentAnything(fl.Module):
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
|
def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
|
||||||
|
"""Compute the embedding of an image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: The image to compute the embedding of.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The computed image embedding.
|
||||||
|
"""
|
||||||
original_size = (image.height, image.width)
|
original_size = (image.height, image.width)
|
||||||
target_size = self.compute_target_size(original_size)
|
target_size = self.compute_target_size(original_size)
|
||||||
return ImageEmbedding(
|
return ImageEmbedding(
|
||||||
|
@ -59,6 +85,21 @@ class SegmentAnything(fl.Module):
|
||||||
low_res_mask: Float[Tensor, "1 1 256 256"] | None = None,
|
low_res_mask: Float[Tensor, "1 1 256 256"] | None = None,
|
||||||
binarize: bool = True,
|
binarize: bool = True,
|
||||||
) -> tuple[Tensor, Tensor, Tensor]:
|
) -> tuple[Tensor, Tensor, Tensor]:
|
||||||
|
"""Predict the masks of the input image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input: The input image or its embedding.
|
||||||
|
foreground_points: The points of the foreground.
|
||||||
|
background_points: The points of the background.
|
||||||
|
box_points: The points of the box.
|
||||||
|
low_res_mask: The low resolution mask.
|
||||||
|
binarize: Whether to binarize the masks.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The predicted masks.
|
||||||
|
The IOU prediction.
|
||||||
|
The low resolution masks.
|
||||||
|
"""
|
||||||
if isinstance(input, ImageEmbedding):
|
if isinstance(input, ImageEmbedding):
|
||||||
original_size = input.original_image_size
|
original_size = input.original_image_size
|
||||||
target_size = self.compute_target_size(original_size)
|
target_size = self.compute_target_size(original_size)
|
||||||
|
@ -107,11 +148,21 @@ class SegmentAnything(fl.Module):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def image_size(self) -> int:
|
def image_size(self) -> int:
|
||||||
|
"""The image size."""
|
||||||
w, h = self.image_encoder.image_size
|
w, h = self.image_encoder.image_size
|
||||||
assert w == h
|
assert w == h
|
||||||
return w
|
return w
|
||||||
|
|
||||||
def compute_target_size(self, size: tuple[int, int]) -> tuple[int, int]:
|
def compute_target_size(self, size: tuple[int, int]) -> tuple[int, int]:
|
||||||
|
"""Compute the target size for a given size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
size: The size of the image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The target height.
|
||||||
|
The target width.
|
||||||
|
"""
|
||||||
oldh, oldw = size
|
oldh, oldw = size
|
||||||
scale = self.image_size * 1.0 / max(oldh, oldw)
|
scale = self.image_size * 1.0 / max(oldh, oldw)
|
||||||
newh, neww = oldh * scale, oldw * scale
|
newh, neww = oldh * scale, oldw * scale
|
||||||
|
@ -120,6 +171,15 @@ class SegmentAnything(fl.Module):
|
||||||
return (newh, neww)
|
return (newh, neww)
|
||||||
|
|
||||||
def preprocess_image(self, image: Image.Image, target_size: tuple[int, int]) -> Tensor:
|
def preprocess_image(self, image: Image.Image, target_size: tuple[int, int]) -> Tensor:
|
||||||
|
"""Preprocess an image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: The image to preprocess.
|
||||||
|
target_size: The target size.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The preprocessed image.
|
||||||
|
"""
|
||||||
h, w = target_size
|
h, w = target_size
|
||||||
padh = self.image_size - h
|
padh = self.image_size - h
|
||||||
padw = self.image_size - w
|
padw = self.image_size - w
|
||||||
|
@ -133,11 +193,31 @@ class SegmentAnything(fl.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
def normalize(self, coordinates: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
|
def normalize(self, coordinates: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
|
||||||
|
"""Normalize the coordinates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
coordinates: The coordinates to normalize.
|
||||||
|
target_size: The target size.
|
||||||
|
original_size: The original size.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The normalized coordinates.
|
||||||
|
"""
|
||||||
coordinates[:, :, 0] = ((coordinates[:, :, 0] * (target_size[1] / original_size[1])) + 0.5) / self.image_size
|
coordinates[:, :, 0] = ((coordinates[:, :, 0] * (target_size[1] / original_size[1])) + 0.5) / self.image_size
|
||||||
coordinates[:, :, 1] = ((coordinates[:, :, 1] * (target_size[0] / original_size[0])) + 0.5) / self.image_size
|
coordinates[:, :, 1] = ((coordinates[:, :, 1] * (target_size[0] / original_size[0])) + 0.5) / self.image_size
|
||||||
return coordinates
|
return coordinates
|
||||||
|
|
||||||
def postprocess_masks(self, masks: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
|
def postprocess_masks(self, masks: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
|
||||||
|
"""Postprocess the masks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
masks: The masks to postprocess.
|
||||||
|
target_size: The target size.
|
||||||
|
original_size: The original size.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The postprocessed masks.
|
||||||
|
"""
|
||||||
masks = interpolate(masks, factor=torch.Size((self.image_size, self.image_size)), mode="bilinear")
|
masks = interpolate(masks, factor=torch.Size((self.image_size, self.image_size)), mode="bilinear")
|
||||||
masks = masks[..., : target_size[0], : target_size[1]] # remove padding added at `preprocess_image` time
|
masks = masks[..., : target_size[0], : target_size[1]] # remove padding added at `preprocess_image` time
|
||||||
masks = interpolate(masks, factor=torch.Size(original_size), mode="bilinear")
|
masks = interpolate(masks, factor=torch.Size(original_size), mode="bilinear")
|
||||||
|
@ -145,6 +225,8 @@ class SegmentAnything(fl.Module):
|
||||||
|
|
||||||
|
|
||||||
class SegmentAnythingH(SegmentAnything):
|
class SegmentAnythingH(SegmentAnything):
|
||||||
|
"""SegmentAnything huge model."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
image_encoder: SAMViTH | None = None,
|
image_encoder: SAMViTH | None = None,
|
||||||
|
@ -154,6 +236,16 @@ class SegmentAnythingH(SegmentAnything):
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: DType = torch.float32,
|
dtype: DType = torch.float32,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Initialize SegmentAnything huge model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_encoder: The image encoder to use.
|
||||||
|
point_encoder: The point encoder to use.
|
||||||
|
mask_encoder: The mask encoder to use.
|
||||||
|
mask_decoder: The mask decoder to use.
|
||||||
|
device: The PyTorch device to use.
|
||||||
|
dtype: The PyTorch data type to use.
|
||||||
|
"""
|
||||||
image_encoder = image_encoder or SAMViTH()
|
image_encoder = image_encoder or SAMViTH()
|
||||||
point_encoder = point_encoder or PointEncoder()
|
point_encoder = point_encoder or PointEncoder()
|
||||||
mask_encoder = mask_encoder or MaskEncoder()
|
mask_encoder = mask_encoder or MaskEncoder()
|
||||||
|
|
Loading…
Reference in a new issue