mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 15:02:01 +00:00
SAM: expose sizing helpers
This commit is contained in:
parent
06ff2f0a5f
commit
0ac290f67d
|
@ -1,17 +1,21 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from jaxtyping import Float
|
||||
from PIL import Image
|
||||
from torch import Tensor, device as Device, dtype as DType
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.utils import interpolate, no_grad, normalize, pad
|
||||
from refiners.fluxion.utils import no_grad
|
||||
from refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH
|
||||
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
|
||||
from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
|
||||
from refiners.foundationals.segment_anything.utils import (
|
||||
normalize_coordinates,
|
||||
postprocess_masks,
|
||||
preprocess_image,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -85,9 +89,8 @@ class SegmentAnything(fl.Chain):
|
|||
The computed image embedding.
|
||||
"""
|
||||
original_size = (image.height, image.width)
|
||||
target_size = self.compute_target_size(original_size)
|
||||
return ImageEmbedding(
|
||||
features=self.image_encoder(self.preprocess_image(image=image, target_size=target_size)),
|
||||
features=self.image_encoder(self.preprocess_image(image)),
|
||||
original_image_size=original_size,
|
||||
)
|
||||
|
||||
|
@ -118,12 +121,10 @@ class SegmentAnything(fl.Chain):
|
|||
"""
|
||||
if isinstance(input, ImageEmbedding):
|
||||
original_size = input.original_image_size
|
||||
target_size = self.compute_target_size(original_size)
|
||||
image_embedding = input.features
|
||||
else:
|
||||
original_size = (input.height, input.width)
|
||||
target_size = self.compute_target_size(original_size)
|
||||
image_embedding = self.image_encoder(self.preprocess_image(image=input, target_size=target_size))
|
||||
image_embedding = self.image_encoder(self.preprocess_image(input))
|
||||
|
||||
coordinates, type_mask = self.point_encoder.points_to_tensor(
|
||||
foreground_points=foreground_points,
|
||||
|
@ -139,9 +140,7 @@ class SegmentAnything(fl.Chain):
|
|||
image_embedding_size=self.image_encoder.image_embedding_size
|
||||
)
|
||||
|
||||
point_embedding = self.point_encoder(
|
||||
self.normalize(coordinates, target_size=target_size, original_size=original_size)
|
||||
)
|
||||
point_embedding = self.point_encoder(self.normalize(coordinates, original_size=original_size))
|
||||
dense_positional_embedding = self.point_encoder.get_dense_positional_embedding(
|
||||
image_embedding_size=self.image_encoder.image_embedding_size
|
||||
)
|
||||
|
@ -153,9 +152,7 @@ class SegmentAnything(fl.Chain):
|
|||
|
||||
low_res_masks, iou_predictions = self.mask_decoder()
|
||||
|
||||
high_res_masks = self.postprocess_masks(
|
||||
masks=low_res_masks, target_size=target_size, original_size=original_size
|
||||
)
|
||||
high_res_masks = self.postprocess_masks(low_res_masks, original_size)
|
||||
|
||||
if binarize:
|
||||
high_res_masks = high_res_masks > self.mask_threshold
|
||||
|
@ -163,82 +160,43 @@ class SegmentAnything(fl.Chain):
|
|||
return high_res_masks, iou_predictions, low_res_masks
|
||||
|
||||
@property
|
||||
def image_size(self) -> int:
|
||||
"""The image size."""
|
||||
def image_encoder_resolution(self) -> int:
|
||||
"""The resolution of the image encoder."""
|
||||
w, h = self.image_encoder.image_size
|
||||
assert w == h
|
||||
return w
|
||||
|
||||
def compute_target_size(self, size: tuple[int, int]) -> tuple[int, int]:
|
||||
"""Compute the target size as expected by the image encoder.
|
||||
|
||||
Args:
|
||||
size: The size of the input image.
|
||||
|
||||
Returns:
|
||||
The target height.
|
||||
The target width.
|
||||
def preprocess_image(self, image: Image.Image) -> Tensor:
|
||||
"""
|
||||
oldh, oldw = size
|
||||
scale = self.image_size * 1.0 / max(oldh, oldw)
|
||||
newh, neww = oldh * scale, oldw * scale
|
||||
neww = int(neww + 0.5)
|
||||
newh = int(newh + 0.5)
|
||||
return (newh, neww)
|
||||
|
||||
def preprocess_image(self, image: Image.Image, target_size: tuple[int, int]) -> Tensor:
|
||||
"""Preprocess an image without distorting its aspect ratio.
|
||||
|
||||
See [`preprocess_image`][refiners.foundationals.segment_anything.utils.preprocess_image]
|
||||
Args:
|
||||
image: The image to preprocess.
|
||||
target_size: The target size.
|
||||
|
||||
Returns:
|
||||
The preprocessed image.
|
||||
The preprocessed tensor.
|
||||
"""
|
||||
h, w = target_size
|
||||
padh = self.image_size - h
|
||||
padw = self.image_size - w
|
||||
resized = image.resize((w, h), resample=Image.Resampling.BILINEAR) # type: ignore
|
||||
image_tensor = torch.tensor(
|
||||
np.array(resized).astype(np.float32).transpose(2, 0, 1),
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
).unsqueeze(0)
|
||||
return pad(
|
||||
normalize(image_tensor, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]), (0, padw, 0, padh)
|
||||
)
|
||||
|
||||
def normalize(self, coordinates: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
|
||||
"""Normalize the coordinates.
|
||||
return preprocess_image(image, self.image_encoder_resolution, self.device, self.dtype)
|
||||
|
||||
def normalize(self, coordinates: Tensor, original_size: tuple[int, int]) -> Tensor:
|
||||
"""
|
||||
See [`normalize_coordinates`][refiners.foundationals.segment_anything.utils.normalize_coordinates]
|
||||
Args:
|
||||
coordinates: The coordinates to normalize.
|
||||
target_size: The target size.
|
||||
original_size: The original size.
|
||||
|
||||
coordinates: a tensor of coordinates.
|
||||
original_size: (h, w) the original size of the image.
|
||||
Returns:
|
||||
The normalized coordinates.
|
||||
The [0,1] normalized coordinates tensor.
|
||||
"""
|
||||
coordinates[:, :, 0] = ((coordinates[:, :, 0] * (target_size[1] / original_size[1])) + 0.5) / self.image_size
|
||||
coordinates[:, :, 1] = ((coordinates[:, :, 1] * (target_size[0] / original_size[0])) + 0.5) / self.image_size
|
||||
return coordinates
|
||||
|
||||
def postprocess_masks(self, masks: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
|
||||
"""Postprocess the masks.
|
||||
return normalize_coordinates(coordinates, original_size, self.image_encoder_resolution)
|
||||
|
||||
def postprocess_masks(self, low_res_masks: Tensor, original_size: tuple[int, int]) -> Tensor:
|
||||
"""
|
||||
See [`postprocess_masks`][refiners.foundationals.segment_anything.utils.postprocess_masks]
|
||||
Args:
|
||||
masks: The masks to postprocess.
|
||||
target_size: The target size.
|
||||
original_size: The original size.
|
||||
|
||||
low_res_masks: a mask tensor of size (N, 1, 256, 256)
|
||||
original_size: (h, w) the original size of the image.
|
||||
Returns:
|
||||
The postprocessed masks.
|
||||
The mask of shape (N, 1, H, W)
|
||||
"""
|
||||
masks = interpolate(masks, size=torch.Size((self.image_size, self.image_size)), mode="bilinear")
|
||||
masks = masks[..., : target_size[0], : target_size[1]] # remove padding added at `preprocess_image` time
|
||||
masks = interpolate(masks, size=torch.Size(original_size), mode="bilinear")
|
||||
return masks
|
||||
return postprocess_masks(low_res_masks, original_size, self.image_encoder_resolution)
|
||||
|
||||
|
||||
class SegmentAnythingH(SegmentAnything):
|
||||
|
|
132
src/refiners/foundationals/segment_anything/utils.py
Normal file
132
src/refiners/foundationals/segment_anything/utils.py
Normal file
|
@ -0,0 +1,132 @@
|
|||
import numpy as np
|
||||
from PIL import Image
|
||||
from torch import Size, Tensor, device as Device, dtype as DType, tensor
|
||||
|
||||
from refiners.fluxion.utils import interpolate, normalize, pad
|
||||
|
||||
|
||||
def compute_scaled_size(size: tuple[int, int], image_encoder_resolution: int) -> tuple[int, int]:
|
||||
"""Compute the scaled size as expected by the image encoder.
|
||||
This computed size keep the ratio of the input image, and scale it to fit inside the square (image_encoder_resolution, image_encoder_resolution) of image encoder.
|
||||
|
||||
Args:
|
||||
size: The size (h, w) of the input image.
|
||||
image_encoder_resolution: Image encoder resolution.
|
||||
|
||||
Returns:
|
||||
The target height.
|
||||
The target width.
|
||||
"""
|
||||
oldh, oldw = size
|
||||
scale = image_encoder_resolution * 1.0 / max(oldh, oldw)
|
||||
newh, neww = oldh * scale, oldw * scale
|
||||
neww = int(neww + 0.5)
|
||||
newh = int(newh + 0.5)
|
||||
return (newh, neww)
|
||||
|
||||
|
||||
def image_to_scaled_tensor(
|
||||
image: Image.Image, scaled_size: tuple[int, int], device: Device | None = None, dtype: DType | None = None
|
||||
) -> Tensor:
|
||||
"""Resize the image to `scaled_size` and convert it to a tensor.
|
||||
|
||||
Args:
|
||||
image: The image.
|
||||
scaled_size: The target size (h, w).
|
||||
device: Tensor device.
|
||||
dtype: Tensor dtype.
|
||||
Returns:
|
||||
a Tensor of shape (1, c, h, w)
|
||||
"""
|
||||
h, w = scaled_size
|
||||
resized = image.resize((w, h), resample=Image.Resampling.BILINEAR) # type: ignore
|
||||
return tensor(
|
||||
np.array(resized).astype(np.float32).transpose(2, 0, 1),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
).unsqueeze(0)
|
||||
|
||||
|
||||
def preprocess_image(
|
||||
image: Image.Image, image_encoder_resolution: int, device: Device | None = None, dtype: DType | None = None
|
||||
) -> Tensor:
|
||||
"""Preprocess an image without distorting its aspect ratio.
|
||||
|
||||
Args:
|
||||
image: The image to preprocess before calling the image encoder.
|
||||
image_encoder_resolution: Image encoder resolution.
|
||||
device: Tensor device (None by default).
|
||||
dtype: Tensor dtype (None by default).
|
||||
|
||||
Returns:
|
||||
The preprocessed image.
|
||||
"""
|
||||
scaled_size = compute_scaled_size((image.height, image.width), image_encoder_resolution)
|
||||
|
||||
image_tensor = image_to_scaled_tensor(image, scaled_size, device, dtype)
|
||||
|
||||
return pad_image_tensor(
|
||||
normalize(image_tensor, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
|
||||
scaled_size,
|
||||
image_encoder_resolution,
|
||||
)
|
||||
|
||||
|
||||
def pad_image_tensor(image_tensor: Tensor, scaled_size: tuple[int, int], image_encoder_resolution: int) -> Tensor:
|
||||
"""Pad an image with zeros to make it square.
|
||||
|
||||
Args:
|
||||
image_tensor: The image tensor to pad.
|
||||
scaled_size: The scaled size (h, w).
|
||||
image_encoder_resolution: Image encoder resolution.
|
||||
|
||||
Returns:
|
||||
The padded image.
|
||||
"""
|
||||
assert len(image_tensor.shape) == 4
|
||||
assert image_tensor.shape[2] <= image_encoder_resolution
|
||||
assert image_tensor.shape[3] <= image_encoder_resolution
|
||||
|
||||
h, w = scaled_size
|
||||
padh = image_encoder_resolution - h
|
||||
padw = image_encoder_resolution - w
|
||||
return pad(image_tensor, (0, padw, 0, padh))
|
||||
|
||||
|
||||
def postprocess_masks(low_res_masks: Tensor, original_size: tuple[int, int], image_encoder_resolution: int) -> Tensor:
|
||||
"""Postprocess the masks to fit the original image size and remove zero-padding (if any).
|
||||
|
||||
Args:
|
||||
low_res_masks: The masks to postprocess.
|
||||
original_size: The original size (h, w).
|
||||
image_encoder_resolution: Image encoder resolution.
|
||||
|
||||
Returns:
|
||||
The postprocessed masks.
|
||||
"""
|
||||
scaled_size = compute_scaled_size(original_size, image_encoder_resolution)
|
||||
masks = interpolate(low_res_masks, size=Size((image_encoder_resolution, image_encoder_resolution)), mode="bilinear")
|
||||
masks = masks[..., : scaled_size[0], : scaled_size[1]] # remove padding added at `preprocess_image` time
|
||||
masks = interpolate(masks, size=Size(original_size), mode="bilinear")
|
||||
return masks
|
||||
|
||||
|
||||
def normalize_coordinates(coordinates: Tensor, original_size: tuple[int, int], image_encoder_resolution: int) -> Tensor:
|
||||
"""Normalize the coordinates in the [0,1] range
|
||||
|
||||
Args:
|
||||
coordinates: The coordinates to normalize.
|
||||
original_size: The original image size.
|
||||
image_encoder_resolution: Image encoder resolution.
|
||||
|
||||
Returns:
|
||||
The normalized coordinates.
|
||||
"""
|
||||
scaled_size = compute_scaled_size(original_size, image_encoder_resolution)
|
||||
coordinates[:, :, 0] = (
|
||||
(coordinates[:, :, 0] * (scaled_size[1] / original_size[1])) + 0.5
|
||||
) / image_encoder_resolution
|
||||
coordinates[:, :, 1] = (
|
||||
(coordinates[:, :, 1] * (scaled_size[0] / original_size[0])) + 0.5
|
||||
) / image_encoder_resolution
|
||||
return coordinates
|
47
tests/foundationals/segment_anything/test_utils.py
Normal file
47
tests/foundationals/segment_anything/test_utils.py
Normal file
|
@ -0,0 +1,47 @@
|
|||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from refiners.foundationals.segment_anything.utils import (
|
||||
compute_scaled_size,
|
||||
image_to_scaled_tensor,
|
||||
pad_image_tensor,
|
||||
preprocess_image,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def image_encoder_resolution() -> int:
|
||||
return 1024
|
||||
|
||||
|
||||
def test_compute_scaled_size(image_encoder_resolution: int) -> None:
|
||||
w, h = (1536, 768)
|
||||
scaled_size = compute_scaled_size((h, w), image_encoder_resolution)
|
||||
|
||||
assert scaled_size == (512, 1024)
|
||||
|
||||
|
||||
def test_image_to_scaled_tensor() -> None:
|
||||
image = Image.new("RGB", (1536, 768))
|
||||
tensor = image_to_scaled_tensor(image, (512, 1024))
|
||||
|
||||
assert tensor.shape == (1, 3, 512, 1024)
|
||||
|
||||
|
||||
def test_preprocess_image(image_encoder_resolution: int) -> None:
|
||||
image = Image.new("RGB", (1536, 768))
|
||||
preprocessed = preprocess_image(image, image_encoder_resolution)
|
||||
|
||||
assert preprocessed.shape == (1, 3, 1024, 1024)
|
||||
|
||||
|
||||
def test_pad_image_tensor(image_encoder_resolution: int) -> None:
|
||||
w, h = (1536, 768)
|
||||
image = Image.new("RGB", (w, h), color="white")
|
||||
scaled_size = compute_scaled_size((h, w), image_encoder_resolution)
|
||||
scaled_image_tensor = image_to_scaled_tensor(image, scaled_size)
|
||||
padded_image_tensor = pad_image_tensor(scaled_image_tensor, scaled_size, image_encoder_resolution)
|
||||
|
||||
assert padded_image_tensor.shape == (1, 3, 1024, 1024)
|
||||
assert torch.all(padded_image_tensor[:, :, 512:, :] == 0)
|
Loading…
Reference in a new issue