From 0ac290f67df1f3d2332fb990a5f434cea3954094 Mon Sep 17 00:00:00 2001
From: Pierre Colle
Date: Thu, 11 Apr 2024 13:09:58 +0000
Subject: [PATCH] SAM: expose sizing helpers

---
 .../foundationals/segment_anything/model.py   | 102 ++++----------
 .../foundationals/segment_anything/utils.py   | 132 ++++++++++++++++++
 .../segment_anything/test_utils.py            |  47 +++++++
 3 files changed, 209 insertions(+), 72 deletions(-)
 create mode 100644 src/refiners/foundationals/segment_anything/utils.py
 create mode 100644 tests/foundationals/segment_anything/test_utils.py

diff --git a/src/refiners/foundationals/segment_anything/model.py b/src/refiners/foundationals/segment_anything/model.py
index 1f94e70..ea4b44a 100644
--- a/src/refiners/foundationals/segment_anything/model.py
+++ b/src/refiners/foundationals/segment_anything/model.py
@@ -1,17 +1,21 @@
 from dataclasses import dataclass
 from typing import Sequence
 
-import numpy as np
 import torch
 from jaxtyping import Float
 from PIL import Image
 from torch import Tensor, device as Device, dtype as DType
 
 import refiners.fluxion.layers as fl
-from refiners.fluxion.utils import interpolate, no_grad, normalize, pad
+from refiners.fluxion.utils import no_grad
 from refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH
 from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
 from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
+from refiners.foundationals.segment_anything.utils import (
+    normalize_coordinates,
+    postprocess_masks,
+    preprocess_image,
+)
 
 
 @dataclass
@@ -85,9 +89,8 @@ class SegmentAnything(fl.Chain):
             The computed image embedding.
         """
         original_size = (image.height, image.width)
-        target_size = self.compute_target_size(original_size)
         return ImageEmbedding(
-            features=self.image_encoder(self.preprocess_image(image=image, target_size=target_size)),
+            features=self.image_encoder(self.preprocess_image(image)),
             original_image_size=original_size,
         )
 
@@ -118,12 +121,10 @@
         """
         if isinstance(input, ImageEmbedding):
             original_size = input.original_image_size
-            target_size = self.compute_target_size(original_size)
             image_embedding = input.features
         else:
             original_size = (input.height, input.width)
-            target_size = self.compute_target_size(original_size)
-            image_embedding = self.image_encoder(self.preprocess_image(image=input, target_size=target_size))
+            image_embedding = self.image_encoder(self.preprocess_image(input))
 
         coordinates, type_mask = self.point_encoder.points_to_tensor(
             foreground_points=foreground_points,
@@ -139,9 +140,7 @@ class SegmentAnything(fl.Chain):
                 image_embedding_size=self.image_encoder.image_embedding_size
             )
 
-        point_embedding = self.point_encoder(
-            self.normalize(coordinates, target_size=target_size, original_size=original_size)
-        )
+        point_embedding = self.point_encoder(self.normalize(coordinates, original_size=original_size))
         dense_positional_embedding = self.point_encoder.get_dense_positional_embedding(
             image_embedding_size=self.image_encoder.image_embedding_size
         )
@@ -153,9 +152,7 @@ class SegmentAnything(fl.Chain):
 
         low_res_masks, iou_predictions = self.mask_decoder()
 
-        high_res_masks = self.postprocess_masks(
-            masks=low_res_masks, target_size=target_size, original_size=original_size
-        )
+        high_res_masks = self.postprocess_masks(low_res_masks, original_size)
 
         if binarize:
             high_res_masks = high_res_masks > self.mask_threshold
@@ -163,82 +160,43 @@ class SegmentAnything(fl.Chain):
         return high_res_masks, iou_predictions, low_res_masks
 
     @property
-    def image_size(self) -> int:
-        """The image size."""
+    def image_encoder_resolution(self) -> int:
+        """The resolution of the image encoder."""
         w, h = self.image_encoder.image_size
         assert w == h
         return w
 
-    def compute_target_size(self, size: tuple[int, int]) -> tuple[int, int]:
-        """Compute the target size as expected by the image encoder.
-
-        Args:
-            size: The size of the input image.
-
-        Returns:
-            The target height.
-            The target width.
+    def preprocess_image(self, image: Image.Image) -> Tensor:
         """
-        oldh, oldw = size
-        scale = self.image_size * 1.0 / max(oldh, oldw)
-        newh, neww = oldh * scale, oldw * scale
-        neww = int(neww + 0.5)
-        newh = int(newh + 0.5)
-        return (newh, neww)
-
-    def preprocess_image(self, image: Image.Image, target_size: tuple[int, int]) -> Tensor:
-        """Preprocess an image without distorting its aspect ratio.
-
+        See [`preprocess_image`][refiners.foundationals.segment_anything.utils.preprocess_image]
         Args:
             image: The image to preprocess.
-            target_size: The target size.
-
-        Returns:
-            The preprocessed image.
+        Returns:
+            The preprocessed tensor.
         """
-        h, w = target_size
-        padh = self.image_size - h
-        padw = self.image_size - w
-        resized = image.resize((w, h), resample=Image.Resampling.BILINEAR)  # type: ignore
-        image_tensor = torch.tensor(
-            np.array(resized).astype(np.float32).transpose(2, 0, 1),
-            device=self.device,
-            dtype=self.dtype,
-        ).unsqueeze(0)
-        return pad(
-            normalize(image_tensor, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]), (0, padw, 0, padh)
-        )
-
-    def normalize(self, coordinates: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
-        """Normalize the coordinates.
+        return preprocess_image(image, self.image_encoder_resolution, self.device, self.dtype)
 
+    def normalize(self, coordinates: Tensor, original_size: tuple[int, int]) -> Tensor:
+        """
+        See [`normalize_coordinates`][refiners.foundationals.segment_anything.utils.normalize_coordinates]
         Args:
-            coordinates: The coordinates to normalize.
-            target_size: The target size.
-            original_size: The original size.
-
+            coordinates: a tensor of coordinates.
+            original_size: (h, w) the original size of the image.
         Returns:
-            The normalized coordinates.
+            The [0,1] normalized coordinates tensor.
         """
-        coordinates[:, :, 0] = ((coordinates[:, :, 0] * (target_size[1] / original_size[1])) + 0.5) / self.image_size
-        coordinates[:, :, 1] = ((coordinates[:, :, 1] * (target_size[0] / original_size[0])) + 0.5) / self.image_size
-        return coordinates
-
-    def postprocess_masks(self, masks: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
-        """Postprocess the masks.
+        return normalize_coordinates(coordinates, original_size, self.image_encoder_resolution)
 
+    def postprocess_masks(self, low_res_masks: Tensor, original_size: tuple[int, int]) -> Tensor:
+        """
+        See [`postprocess_masks`][refiners.foundationals.segment_anything.utils.postprocess_masks]
         Args:
-            masks: The masks to postprocess.
-            target_size: The target size.
-            original_size: The original size.
-
+            low_res_masks: a mask tensor of size (N, 1, 256, 256)
+            original_size: (h, w) the original size of the image.
         Returns:
-            The postprocessed masks.
+            The mask of shape (N, 1, H, W)
         """
-        masks = interpolate(masks, size=torch.Size((self.image_size, self.image_size)), mode="bilinear")
-        masks = masks[..., : target_size[0], : target_size[1]]  # remove padding added at `preprocess_image` time
-        masks = interpolate(masks, size=torch.Size(original_size), mode="bilinear")
-        return masks
+        return postprocess_masks(low_res_masks, original_size, self.image_encoder_resolution)
 
 
 class SegmentAnythingH(SegmentAnything):
diff --git a/src/refiners/foundationals/segment_anything/utils.py b/src/refiners/foundationals/segment_anything/utils.py
new file mode 100644
index 0000000..026e159
--- /dev/null
+++ b/src/refiners/foundationals/segment_anything/utils.py
@@ -0,0 +1,132 @@
+import numpy as np
+from PIL import Image
+from torch import Size, Tensor, device as Device, dtype as DType, tensor
+
+from refiners.fluxion.utils import interpolate, normalize, pad
+
+
+def compute_scaled_size(size: tuple[int, int], image_encoder_resolution: int) -> tuple[int, int]:
+    """Compute the scaled size as expected by the image encoder.
+    This computed size keeps the aspect ratio of the input image and scales it to fit inside the (image_encoder_resolution, image_encoder_resolution) square of the image encoder.
+
+    Args:
+        size: The size (h, w) of the input image.
+        image_encoder_resolution: Image encoder resolution.
+
+    Returns:
+        The scaled height.
+        The scaled width.
+    """
+    oldh, oldw = size
+    scale = image_encoder_resolution * 1.0 / max(oldh, oldw)
+    newh, neww = oldh * scale, oldw * scale
+    neww = int(neww + 0.5)
+    newh = int(newh + 0.5)
+    return (newh, neww)
+
+
+def image_to_scaled_tensor(
+    image: Image.Image, scaled_size: tuple[int, int], device: Device | None = None, dtype: DType | None = None
+) -> Tensor:
+    """Resize the image to `scaled_size` and convert it to a tensor.
+
+    Args:
+        image: The image.
+        scaled_size: The target size (h, w).
+        device: Tensor device.
+        dtype: Tensor dtype.
+    Returns:
+        A tensor of shape (1, c, h, w).
+    """
+    h, w = scaled_size
+    resized = image.resize((w, h), resample=Image.Resampling.BILINEAR)  # type: ignore
+    return tensor(
+        np.array(resized).astype(np.float32).transpose(2, 0, 1),
+        device=device,
+        dtype=dtype,
+    ).unsqueeze(0)
+
+
+def preprocess_image(
+    image: Image.Image, image_encoder_resolution: int, device: Device | None = None, dtype: DType | None = None
+) -> Tensor:
+    """Preprocess an image without distorting its aspect ratio.
+
+    Args:
+        image: The image to preprocess before calling the image encoder.
+        image_encoder_resolution: Image encoder resolution.
+        device: Tensor device (None by default).
+        dtype: Tensor dtype (None by default).
+
+    Returns:
+        The preprocessed image.
+    """
+    scaled_size = compute_scaled_size((image.height, image.width), image_encoder_resolution)
+
+    image_tensor = image_to_scaled_tensor(image, scaled_size, device, dtype)
+
+    return pad_image_tensor(
+        normalize(image_tensor, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+        scaled_size,
+        image_encoder_resolution,
+    )
+
+
+def pad_image_tensor(image_tensor: Tensor, scaled_size: tuple[int, int], image_encoder_resolution: int) -> Tensor:
+    """Pad an image with zeros to make it square.
+
+    Args:
+        image_tensor: The image tensor to pad.
+        scaled_size: The scaled size (h, w).
+        image_encoder_resolution: Image encoder resolution.
+
+    Returns:
+        The padded image.
+    """
+    assert len(image_tensor.shape) == 4
+    assert image_tensor.shape[2] <= image_encoder_resolution
+    assert image_tensor.shape[3] <= image_encoder_resolution
+
+    h, w = scaled_size
+    padh = image_encoder_resolution - h
+    padw = image_encoder_resolution - w
+    return pad(image_tensor, (0, padw, 0, padh))
+
+
+def postprocess_masks(low_res_masks: Tensor, original_size: tuple[int, int], image_encoder_resolution: int) -> Tensor:
+    """Postprocess the masks to fit the original image size and remove zero-padding (if any).
+
+    Args:
+        low_res_masks: The masks to postprocess.
+        original_size: The original size (h, w).
+        image_encoder_resolution: Image encoder resolution.
+
+    Returns:
+        The postprocessed masks.
+    """
+    scaled_size = compute_scaled_size(original_size, image_encoder_resolution)
+    masks = interpolate(low_res_masks, size=Size((image_encoder_resolution, image_encoder_resolution)), mode="bilinear")
+    masks = masks[..., : scaled_size[0], : scaled_size[1]]  # remove padding added at `preprocess_image` time
+    masks = interpolate(masks, size=Size(original_size), mode="bilinear")
+    return masks
+
+
+def normalize_coordinates(coordinates: Tensor, original_size: tuple[int, int], image_encoder_resolution: int) -> Tensor:
+    """Normalize the coordinates to the [0, 1] range.
+
+    Args:
+        coordinates: The coordinates to normalize.
+        original_size: The original image size.
+        image_encoder_resolution: Image encoder resolution.
+
+    Returns:
+        The normalized coordinates.
+    """
+    scaled_size = compute_scaled_size(original_size, image_encoder_resolution)
+    coordinates[:, :, 0] = (
+        (coordinates[:, :, 0] * (scaled_size[1] / original_size[1])) + 0.5
+    ) / image_encoder_resolution
+    coordinates[:, :, 1] = (
+        (coordinates[:, :, 1] * (scaled_size[0] / original_size[0])) + 0.5
+    ) / image_encoder_resolution
+    return coordinates
diff --git a/tests/foundationals/segment_anything/test_utils.py b/tests/foundationals/segment_anything/test_utils.py
new file mode 100644
index 0000000..6966cf9
--- /dev/null
+++ b/tests/foundationals/segment_anything/test_utils.py
@@ -0,0 +1,47 @@
+import pytest
+import torch
+from PIL import Image
+
+from refiners.foundationals.segment_anything.utils import (
+    compute_scaled_size,
+    image_to_scaled_tensor,
+    pad_image_tensor,
+    preprocess_image,
+)
+
+
+@pytest.fixture
+def image_encoder_resolution() -> int:
+    return 1024
+
+
+def test_compute_scaled_size(image_encoder_resolution: int) -> None:
+    w, h = (1536, 768)
+    scaled_size = compute_scaled_size((h, w), image_encoder_resolution)
+
+    assert scaled_size == (512, 1024)
+
+
+def test_image_to_scaled_tensor() -> None:
+    image = Image.new("RGB", (1536, 768))
+    tensor = image_to_scaled_tensor(image, (512, 1024))
+
+    assert tensor.shape == (1, 3, 512, 1024)
+
+
+def test_preprocess_image(image_encoder_resolution: int) -> None:
+    image = Image.new("RGB", (1536, 768))
+    preprocessed = preprocess_image(image, image_encoder_resolution)
+
+    assert preprocessed.shape == (1, 3, 1024, 1024)
+
+
+def test_pad_image_tensor(image_encoder_resolution: int) -> None:
+    w, h = (1536, 768)
+    image = Image.new("RGB", (w, h), color="white")
+    scaled_size = compute_scaled_size((h, w), image_encoder_resolution)
+    scaled_image_tensor = image_to_scaled_tensor(image, scaled_size)
+    padded_image_tensor = pad_image_tensor(scaled_image_tensor, scaled_size, image_encoder_resolution)
+
+    assert padded_image_tensor.shape == (1, 3, 1024, 1024)
+    assert torch.all(padded_image_tensor[:, :, 512:, :] == 0)
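Example (not part of the patch): a minimal usage sketch of the helpers this diff introduces, assuming the 1024 image encoder resolution and the 1536x768 sample size used in the tests above; the point prompt and the zero mask tensor are hypothetical placeholders.

    import torch
    from PIL import Image

    from refiners.foundationals.segment_anything.utils import (
        compute_scaled_size,
        normalize_coordinates,
        postprocess_masks,
        preprocess_image,
    )

    image_encoder_resolution = 1024  # value from the test fixture above (SAM ViT-H)
    image = Image.new("RGB", (1536, 768))  # placeholder image; any RGB PIL image works

    # Scaled (h, w) fitting the 1024x1024 encoder square while keeping the aspect ratio -> (512, 1024)
    scaled_size = compute_scaled_size((image.height, image.width), image_encoder_resolution)

    # Resize, normalize and zero-pad to (1, 3, 1024, 1024), ready for the image encoder
    encoder_input = preprocess_image(image, image_encoder_resolution)

    # Map pixel-space point prompts into the [0, 1] range expected by the point encoder
    points = torch.tensor([[[300.0, 200.0]]])  # (batch, num_points, xy), hypothetical prompt
    normalized_points = normalize_coordinates(points, (image.height, image.width), image_encoder_resolution)

    # Upscale decoder masks (N, 1, 256, 256) back to the original size, cropping the padding
    low_res_masks = torch.zeros(1, 1, 256, 256)  # stand-in for MaskDecoder output
    high_res_masks = postprocess_masks(low_res_masks, (image.height, image.width), image_encoder_resolution)
    assert high_res_masks.shape == (1, 1, 768, 1536)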