SAM: expose sizing helpers

Pierre Colle 2024-04-11 13:09:58 +00:00 committed by Colle
parent 06ff2f0a5f
commit 0ac290f67d
3 changed files with 209 additions and 72 deletions
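
In practice, exposing the sizing helpers means the scaling, padding, coordinate and mask logic becomes importable on its own from the new utils module shown below, instead of only being reachable as `SegmentAnything` methods:

from refiners.foundationals.segment_anything.utils import (
    compute_scaled_size,
    normalize_coordinates,
    postprocess_masks,
    preprocess_image,
)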

@@ -1,17 +1,21 @@
 from dataclasses import dataclass
 from typing import Sequence
 
-import numpy as np
 import torch
 from jaxtyping import Float
 from PIL import Image
 from torch import Tensor, device as Device, dtype as DType
 
 import refiners.fluxion.layers as fl
-from refiners.fluxion.utils import interpolate, no_grad, normalize, pad
+from refiners.fluxion.utils import no_grad
 from refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH
 from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
 from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
+from refiners.foundationals.segment_anything.utils import (
+    normalize_coordinates,
+    postprocess_masks,
+    preprocess_image,
+)
 
 
 @dataclass
@@ -85,9 +89,8 @@ class SegmentAnything(fl.Chain):
             The computed image embedding.
         """
         original_size = (image.height, image.width)
-        target_size = self.compute_target_size(original_size)
         return ImageEmbedding(
-            features=self.image_encoder(self.preprocess_image(image=image, target_size=target_size)),
+            features=self.image_encoder(self.preprocess_image(image)),
             original_image_size=original_size,
         )
 
@@ -118,12 +121,10 @@ class SegmentAnything(fl.Chain):
         """
         if isinstance(input, ImageEmbedding):
             original_size = input.original_image_size
-            target_size = self.compute_target_size(original_size)
             image_embedding = input.features
         else:
             original_size = (input.height, input.width)
-            target_size = self.compute_target_size(original_size)
-            image_embedding = self.image_encoder(self.preprocess_image(image=input, target_size=target_size))
+            image_embedding = self.image_encoder(self.preprocess_image(input))
 
         coordinates, type_mask = self.point_encoder.points_to_tensor(
             foreground_points=foreground_points,
@@ -139,9 +140,7 @@ class SegmentAnything(fl.Chain):
             image_embedding_size=self.image_encoder.image_embedding_size
         )
 
-        point_embedding = self.point_encoder(
-            self.normalize(coordinates, target_size=target_size, original_size=original_size)
-        )
+        point_embedding = self.point_encoder(self.normalize(coordinates, original_size=original_size))
         dense_positional_embedding = self.point_encoder.get_dense_positional_embedding(
             image_embedding_size=self.image_encoder.image_embedding_size
         )
@@ -153,9 +152,7 @@ class SegmentAnything(fl.Chain):
         low_res_masks, iou_predictions = self.mask_decoder()
 
-        high_res_masks = self.postprocess_masks(
-            masks=low_res_masks, target_size=target_size, original_size=original_size
-        )
+        high_res_masks = self.postprocess_masks(low_res_masks, original_size)
 
         if binarize:
             high_res_masks = high_res_masks > self.mask_threshold
@@ -163,82 +160,43 @@ class SegmentAnything(fl.Chain):
         return high_res_masks, iou_predictions, low_res_masks
 
     @property
-    def image_size(self) -> int:
-        """The image size."""
+    def image_encoder_resolution(self) -> int:
+        """The resolution of the image encoder."""
         w, h = self.image_encoder.image_size
         assert w == h
         return w
 
-    def compute_target_size(self, size: tuple[int, int]) -> tuple[int, int]:
-        """Compute the target size as expected by the image encoder.
-
-        Args:
-            size: The size of the input image.
-
-        Returns:
-            The target height.
-            The target width.
-        """
-        oldh, oldw = size
-        scale = self.image_size * 1.0 / max(oldh, oldw)
-        newh, neww = oldh * scale, oldw * scale
-        neww = int(neww + 0.5)
-        newh = int(newh + 0.5)
-        return (newh, neww)
-
-    def preprocess_image(self, image: Image.Image, target_size: tuple[int, int]) -> Tensor:
-        """Preprocess an image without distorting its aspect ratio.
-
-        Args:
-            image: The image to preprocess.
-            target_size: The target size.
-
-        Returns:
-            The preprocessed image.
-        """
-        h, w = target_size
-        padh = self.image_size - h
-        padw = self.image_size - w
-        resized = image.resize((w, h), resample=Image.Resampling.BILINEAR)  # type: ignore
-        image_tensor = torch.tensor(
-            np.array(resized).astype(np.float32).transpose(2, 0, 1),
-            device=self.device,
-            dtype=self.dtype,
-        ).unsqueeze(0)
-        return pad(
-            normalize(image_tensor, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]), (0, padw, 0, padh)
-        )
-
-    def normalize(self, coordinates: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
-        """Normalize the coordinates.
-
-        Args:
-            coordinates: The coordinates to normalize.
-            target_size: The target size.
-            original_size: The original size.
-
-        Returns:
-            The normalized coordinates.
-        """
-        coordinates[:, :, 0] = ((coordinates[:, :, 0] * (target_size[1] / original_size[1])) + 0.5) / self.image_size
-        coordinates[:, :, 1] = ((coordinates[:, :, 1] * (target_size[0] / original_size[0])) + 0.5) / self.image_size
-        return coordinates
-
-    def postprocess_masks(self, masks: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
-        """Postprocess the masks.
-
-        Args:
-            masks: The masks to postprocess.
-            target_size: The target size.
-            original_size: The original size.
-
-        Returns:
-            The postprocessed masks.
-        """
-        masks = interpolate(masks, size=torch.Size((self.image_size, self.image_size)), mode="bilinear")
-        masks = masks[..., : target_size[0], : target_size[1]]  # remove padding added at `preprocess_image` time
-        masks = interpolate(masks, size=torch.Size(original_size), mode="bilinear")
-        return masks
+    def preprocess_image(self, image: Image.Image) -> Tensor:
+        """
+        See [`preprocess_image`][refiners.foundationals.segment_anything.utils.preprocess_image]
+        Args:
+            image: The image to preprocess.
+        Returns:
+            The preprocessed tensor.
+        """
+        return preprocess_image(image, self.image_encoder_resolution, self.device, self.dtype)
+
+    def normalize(self, coordinates: Tensor, original_size: tuple[int, int]) -> Tensor:
+        """
+        See [`normalize_coordinates`][refiners.foundationals.segment_anything.utils.normalize_coordinates]
+        Args:
+            coordinates: A tensor of coordinates.
+            original_size: The original size (h, w) of the image.
+        Returns:
+            The [0, 1]-normalized coordinates tensor.
+        """
+        return normalize_coordinates(coordinates, original_size, self.image_encoder_resolution)
+
+    def postprocess_masks(self, low_res_masks: Tensor, original_size: tuple[int, int]) -> Tensor:
+        """
+        See [`postprocess_masks`][refiners.foundationals.segment_anything.utils.postprocess_masks]
+        Args:
+            low_res_masks: A mask tensor of shape (N, 1, 256, 256).
+            original_size: The original size (h, w) of the image.
+        Returns:
+            The mask of shape (N, 1, H, W).
+        """
+        return postprocess_masks(low_res_masks, original_size, self.image_encoder_resolution)
 
 
 class SegmentAnythingH(SegmentAnything):
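
A minimal sketch of how the slimmed-down methods fit together after this change (the import path and the default SegmentAnythingH() construction are assumptions, and loading converted SAM weights is omitted):

import torch
from PIL import Image

from refiners.foundationals.segment_anything import SegmentAnythingH  # import path assumed

sam = SegmentAnythingH()  # assumption: default construction; weights are not loaded here
image = Image.open("photo.jpg")  # hypothetical input image
original_size = (image.height, image.width)

# preprocess_image no longer needs a target_size: scaling, pixel normalization and
# padding to the square image_encoder_resolution happen inside the helper.
features = sam.image_encoder(sam.preprocess_image(image))

# postprocess_masks now only needs the original (h, w); the low-res masks would come
# from the mask decoder, a zero tensor stands in for them here.
low_res_masks = torch.zeros(1, 1, 256, 256)
high_res_masks = sam.postprocess_masks(low_res_masks, original_size)
assert high_res_masks.shape == (1, 1, image.height, image.width)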

@@ -0,0 +1,132 @@
import numpy as np
from PIL import Image
from torch import Size, Tensor, device as Device, dtype as DType, tensor

from refiners.fluxion.utils import interpolate, normalize, pad


def compute_scaled_size(size: tuple[int, int], image_encoder_resolution: int) -> tuple[int, int]:
    """Compute the scaled size as expected by the image encoder.

    This computed size keeps the aspect ratio of the input image and scales it to fit
    inside the (image_encoder_resolution, image_encoder_resolution) square of the image encoder.

    Args:
        size: The size (h, w) of the input image.
        image_encoder_resolution: Image encoder resolution.

    Returns:
        The target height.
        The target width.
    """
    oldh, oldw = size
    scale = image_encoder_resolution * 1.0 / max(oldh, oldw)
    newh, neww = oldh * scale, oldw * scale
    neww = int(neww + 0.5)
    newh = int(newh + 0.5)
    return (newh, neww)

def image_to_scaled_tensor(
    image: Image.Image, scaled_size: tuple[int, int], device: Device | None = None, dtype: DType | None = None
) -> Tensor:
    """Resize the image to `scaled_size` and convert it to a tensor.

    Args:
        image: The image.
        scaled_size: The target size (h, w).
        device: Tensor device.
        dtype: Tensor dtype.

    Returns:
        A tensor of shape (1, c, h, w).
    """
    h, w = scaled_size
    resized = image.resize((w, h), resample=Image.Resampling.BILINEAR)  # type: ignore
    return tensor(
        np.array(resized).astype(np.float32).transpose(2, 0, 1),
        device=device,
        dtype=dtype,
    ).unsqueeze(0)

def preprocess_image(
    image: Image.Image, image_encoder_resolution: int, device: Device | None = None, dtype: DType | None = None
) -> Tensor:
    """Preprocess an image without distorting its aspect ratio.

    Args:
        image: The image to preprocess before calling the image encoder.
        image_encoder_resolution: Image encoder resolution.
        device: Tensor device (None by default).
        dtype: Tensor dtype (None by default).

    Returns:
        The preprocessed image.
    """
    scaled_size = compute_scaled_size((image.height, image.width), image_encoder_resolution)
    image_tensor = image_to_scaled_tensor(image, scaled_size, device, dtype)
    return pad_image_tensor(
        normalize(image_tensor, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
        scaled_size,
        image_encoder_resolution,
    )

def pad_image_tensor(image_tensor: Tensor, scaled_size: tuple[int, int], image_encoder_resolution: int) -> Tensor:
    """Pad an image with zeros to make it square.

    Args:
        image_tensor: The image tensor to pad.
        scaled_size: The scaled size (h, w).
        image_encoder_resolution: Image encoder resolution.

    Returns:
        The padded image.
    """
    assert len(image_tensor.shape) == 4
    assert image_tensor.shape[2] <= image_encoder_resolution
    assert image_tensor.shape[3] <= image_encoder_resolution

    h, w = scaled_size
    padh = image_encoder_resolution - h
    padw = image_encoder_resolution - w
    return pad(image_tensor, (0, padw, 0, padh))

def postprocess_masks(low_res_masks: Tensor, original_size: tuple[int, int], image_encoder_resolution: int) -> Tensor:
    """Postprocess the masks to fit the original image size and remove zero-padding (if any).

    Args:
        low_res_masks: The masks to postprocess.
        original_size: The original size (h, w).
        image_encoder_resolution: Image encoder resolution.

    Returns:
        The postprocessed masks.
    """
    scaled_size = compute_scaled_size(original_size, image_encoder_resolution)
    masks = interpolate(low_res_masks, size=Size((image_encoder_resolution, image_encoder_resolution)), mode="bilinear")
    masks = masks[..., : scaled_size[0], : scaled_size[1]]  # remove padding added at `preprocess_image` time
    masks = interpolate(masks, size=Size(original_size), mode="bilinear")
    return masks

def normalize_coordinates(coordinates: Tensor, original_size: tuple[int, int], image_encoder_resolution: int) -> Tensor:
    """Normalize the coordinates to the [0, 1] range.

    Args:
        coordinates: The coordinates to normalize.
        original_size: The original image size.
        image_encoder_resolution: Image encoder resolution.

    Returns:
        The normalized coordinates.
    """
    scaled_size = compute_scaled_size(original_size, image_encoder_resolution)
    coordinates[:, :, 0] = (
        (coordinates[:, :, 0] * (scaled_size[1] / original_size[1])) + 0.5
    ) / image_encoder_resolution
    coordinates[:, :, 1] = (
        (coordinates[:, :, 1] * (scaled_size[0] / original_size[0])) + 0.5
    ) / image_encoder_resolution
    return coordinates
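
Taken together, these helpers compose into the same pipeline the model methods delegate to. A short sketch using the sizes exercised by the tests below (the point prompt is a hypothetical example):

import torch
from PIL import Image

from refiners.foundationals.segment_anything.utils import (
    compute_scaled_size,
    image_to_scaled_tensor,
    normalize_coordinates,
    pad_image_tensor,
    postprocess_masks,
    preprocess_image,
)

image = Image.new("RGB", (1536, 768))  # (w, h) = (1536, 768), as in the tests
resolution = 1024

# Scale to fit the 1024x1024 square, convert to a (1, c, h, w) tensor, then zero-pad.
scaled_size = compute_scaled_size((image.height, image.width), resolution)  # (512, 1024)
scaled = image_to_scaled_tensor(image, scaled_size)  # shape (1, 3, 512, 1024)
padded = pad_image_tensor(scaled, scaled_size, resolution)  # shape (1, 3, 1024, 1024)

# preprocess_image does the same in one call, plus the SAM pixel normalization.
preprocessed = preprocess_image(image, resolution)  # shape (1, 3, 1024, 1024)

# Point prompts given in original-image pixels map to [0, 1] coordinates relative
# to the scaled content inside the square.
points = torch.tensor([[[768.0, 384.0]]])  # one (x, y) point, shape (1, 1, 2)
normalized = normalize_coordinates(points, (image.height, image.width), resolution)

# Low-resolution decoder masks (N, 1, 256, 256) are mapped back to (N, 1, H, W).
low_res_masks = torch.zeros(1, 1, 256, 256)
high_res_masks = postprocess_masks(low_res_masks, (image.height, image.width), resolution)
assert high_res_masks.shape == (1, 1, 768, 1536)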

@@ -0,0 +1,47 @@
import pytest
import torch
from PIL import Image

from refiners.foundationals.segment_anything.utils import (
    compute_scaled_size,
    image_to_scaled_tensor,
    pad_image_tensor,
    preprocess_image,
)


@pytest.fixture
def image_encoder_resolution() -> int:
    return 1024


def test_compute_scaled_size(image_encoder_resolution: int) -> None:
    w, h = (1536, 768)

    scaled_size = compute_scaled_size((h, w), image_encoder_resolution)

    assert scaled_size == (512, 1024)


def test_image_to_scaled_tensor() -> None:
    image = Image.new("RGB", (1536, 768))

    tensor = image_to_scaled_tensor(image, (512, 1024))

    assert tensor.shape == (1, 3, 512, 1024)


def test_preprocess_image(image_encoder_resolution: int) -> None:
    image = Image.new("RGB", (1536, 768))

    preprocessed = preprocess_image(image, image_encoder_resolution)

    assert preprocessed.shape == (1, 3, 1024, 1024)


def test_pad_image_tensor(image_encoder_resolution: int) -> None:
    w, h = (1536, 768)
    image = Image.new("RGB", (w, h), color="white")

    scaled_size = compute_scaled_size((h, w), image_encoder_resolution)
    scaled_image_tensor = image_to_scaled_tensor(image, scaled_size)
    padded_image_tensor = pad_image_tensor(scaled_image_tensor, scaled_size, image_encoder_resolution)

    assert padded_image_tensor.shape == (1, 3, 1024, 1024)
    assert torch.all(padded_image_tensor[:, :, 512:, :] == 0)
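
Two companion tests along the same lines could cover the helpers the file above leaves out; this is a hypothetical sketch, not part of this diff, and it would also need normalize_coordinates and postprocess_masks added to the import block:

def test_normalize_coordinates(image_encoder_resolution: int) -> None:
    w, h = (1536, 768)
    coordinates = torch.tensor([[[w / 2, h / 2]]])  # one (x, y) point at the image center

    normalized = normalize_coordinates(coordinates, (h, w), image_encoder_resolution)

    assert normalized.shape == (1, 1, 2)
    assert torch.all((normalized >= 0) & (normalized <= 1))


def test_postprocess_masks(image_encoder_resolution: int) -> None:
    w, h = (1536, 768)
    low_res_masks = torch.zeros(1, 1, 256, 256)

    high_res_masks = postprocess_masks(low_res_masks, (h, w), image_encoder_resolution)

    assert high_res_masks.shape == (1, 1, h, w)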