SAM: expose sizing helpers

Pierre Colle 2024-04-11 13:09:58 +00:00 committed by Colle
parent 06ff2f0a5f
commit 0ac290f67d
3 changed files with 209 additions and 72 deletions


@@ -1,17 +1,21 @@
from dataclasses import dataclass
from typing import Sequence
import numpy as np
import torch
from jaxtyping import Float
from PIL import Image
from torch import Tensor, device as Device, dtype as DType
import refiners.fluxion.layers as fl
from refiners.fluxion.utils import interpolate, no_grad, normalize, pad
from refiners.fluxion.utils import no_grad
from refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
from refiners.foundationals.segment_anything.utils import (
normalize_coordinates,
postprocess_masks,
preprocess_image,
)
@dataclass
@@ -85,9 +89,8 @@ class SegmentAnything(fl.Chain):
The computed image embedding.
"""
original_size = (image.height, image.width)
target_size = self.compute_target_size(original_size)
return ImageEmbedding(
features=self.image_encoder(self.preprocess_image(image=image, target_size=target_size)),
features=self.image_encoder(self.preprocess_image(image)),
original_image_size=original_size,
)
@@ -118,12 +121,10 @@ class SegmentAnything(fl.Chain):
"""
if isinstance(input, ImageEmbedding):
original_size = input.original_image_size
target_size = self.compute_target_size(original_size)
image_embedding = input.features
else:
original_size = (input.height, input.width)
target_size = self.compute_target_size(original_size)
image_embedding = self.image_encoder(self.preprocess_image(image=input, target_size=target_size))
image_embedding = self.image_encoder(self.preprocess_image(input))
coordinates, type_mask = self.point_encoder.points_to_tensor(
foreground_points=foreground_points,
@@ -139,9 +140,7 @@ class SegmentAnything(fl.Chain):
image_embedding_size=self.image_encoder.image_embedding_size
)
point_embedding = self.point_encoder(
self.normalize(coordinates, target_size=target_size, original_size=original_size)
)
point_embedding = self.point_encoder(self.normalize(coordinates, original_size=original_size))
dense_positional_embedding = self.point_encoder.get_dense_positional_embedding(
image_embedding_size=self.image_encoder.image_embedding_size
)
@@ -153,9 +152,7 @@ class SegmentAnything(fl.Chain):
low_res_masks, iou_predictions = self.mask_decoder()
high_res_masks = self.postprocess_masks(
masks=low_res_masks, target_size=target_size, original_size=original_size
)
high_res_masks = self.postprocess_masks(low_res_masks, original_size)
if binarize:
high_res_masks = high_res_masks > self.mask_threshold
@@ -163,82 +160,43 @@ class SegmentAnything(fl.Chain):
return high_res_masks, iou_predictions, low_res_masks
@property
def image_size(self) -> int:
"""The image size."""
def image_encoder_resolution(self) -> int:
"""The resolution of the image encoder."""
w, h = self.image_encoder.image_size
assert w == h
return w
def compute_target_size(self, size: tuple[int, int]) -> tuple[int, int]:
"""Compute the target size as expected by the image encoder.
Args:
size: The size of the input image.
Returns:
The target height.
The target width.
def preprocess_image(self, image: Image.Image) -> Tensor:
"""
oldh, oldw = size
scale = self.image_size * 1.0 / max(oldh, oldw)
newh, neww = oldh * scale, oldw * scale
neww = int(neww + 0.5)
newh = int(newh + 0.5)
return (newh, neww)
def preprocess_image(self, image: Image.Image, target_size: tuple[int, int]) -> Tensor:
"""Preprocess an image without distorting its aspect ratio.
See [`preprocess_image`][refiners.foundationals.segment_anything.utils.preprocess_image]
Args:
image: The image to preprocess.
target_size: The target size.
Returns:
The preprocessed image.
The preprocessed tensor.
"""
h, w = target_size
padh = self.image_size - h
padw = self.image_size - w
resized = image.resize((w, h), resample=Image.Resampling.BILINEAR) # type: ignore
image_tensor = torch.tensor(
np.array(resized).astype(np.float32).transpose(2, 0, 1),
device=self.device,
dtype=self.dtype,
).unsqueeze(0)
return pad(
normalize(image_tensor, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]), (0, padw, 0, padh)
)
def normalize(self, coordinates: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
"""Normalize the coordinates.
return preprocess_image(image, self.image_encoder_resolution, self.device, self.dtype)
def normalize(self, coordinates: Tensor, original_size: tuple[int, int]) -> Tensor:
"""
See [`normalize_coordinates`][refiners.foundationals.segment_anything.utils.normalize_coordinates]
Args:
coordinates: The coordinates to normalize.
target_size: The target size.
original_size: The original size.
coordinates: A tensor of point coordinates.
original_size: The original image size (h, w).
Returns:
The normalized coordinates.
The [0,1] normalized coordinates tensor.
"""
coordinates[:, :, 0] = ((coordinates[:, :, 0] * (target_size[1] / original_size[1])) + 0.5) / self.image_size
coordinates[:, :, 1] = ((coordinates[:, :, 1] * (target_size[0] / original_size[0])) + 0.5) / self.image_size
return coordinates
def postprocess_masks(self, masks: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
"""Postprocess the masks.
return normalize_coordinates(coordinates, original_size, self.image_encoder_resolution)
def postprocess_masks(self, low_res_masks: Tensor, original_size: tuple[int, int]) -> Tensor:
"""
See [`postprocess_masks`][refiners.foundationals.segment_anything.utils.postprocess_masks]
Args:
masks: The masks to postprocess.
target_size: The target size.
original_size: The original size.
low_res_masks: A mask tensor of shape (N, 1, 256, 256).
original_size: The original image size (h, w).
Returns:
The postprocessed masks.
The mask tensor of shape (N, 1, H, W), where (H, W) is the original image size.
"""
masks = interpolate(masks, size=torch.Size((self.image_size, self.image_size)), mode="bilinear")
masks = masks[..., : target_size[0], : target_size[1]] # remove padding added at `preprocess_image` time
masks = interpolate(masks, size=torch.Size(original_size), mode="bilinear")
return masks
return postprocess_masks(low_res_masks, original_size, self.image_encoder_resolution)
class SegmentAnythingH(SegmentAnything):
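For context, a minimal usage sketch of the simplified model-level API after this change. It is hedged: the module import path, the default SegmentAnythingH construction and the synthetic image are assumptions for illustration, not part of this diff.
import torch
from PIL import Image

from refiners.fluxion.utils import no_grad
from refiners.foundationals.segment_anything.model import SegmentAnythingH

# Hypothetical setup: a randomly initialized ViT-H model on CPU and a synthetic
# image; real use would load pretrained weights and open an actual picture.
sam = SegmentAnythingH()
image = Image.new("RGB", (1536, 768), color="white")

with no_grad():
    # The scaled size is now derived internally from `image_encoder_resolution`,
    # so callers no longer thread a target_size through these helpers.
    preprocessed = sam.preprocess_image(image)
    assert preprocessed.shape[-2:] == (sam.image_encoder_resolution, sam.image_encoder_resolution)

    # Likewise, postprocessing only needs the original image size.
    low_res_masks = torch.zeros(1, 1, 256, 256, device=sam.device, dtype=sam.dtype)
    high_res_masks = sam.postprocess_masks(low_res_masks, original_size=(image.height, image.width))
    assert high_res_masks.shape[-2:] == (image.height, image.width)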


@@ -0,0 +1,132 @@
import numpy as np
from PIL import Image
from torch import Size, Tensor, device as Device, dtype as DType, tensor
from refiners.fluxion.utils import interpolate, normalize, pad
def compute_scaled_size(size: tuple[int, int], image_encoder_resolution: int) -> tuple[int, int]:
"""Compute the scaled size as expected by the image encoder.
This computed size keeps the aspect ratio of the input image and scales it to fit inside the (image_encoder_resolution, image_encoder_resolution) square of the image encoder.
Args:
size: The size (h, w) of the input image.
image_encoder_resolution: Image encoder resolution.
Returns:
The scaled height.
The scaled width.
"""
oldh, oldw = size
scale = image_encoder_resolution * 1.0 / max(oldh, oldw)
newh, neww = oldh * scale, oldw * scale
neww = int(neww + 0.5)
newh = int(newh + 0.5)
return (newh, neww)
def image_to_scaled_tensor(
image: Image.Image, scaled_size: tuple[int, int], device: Device | None = None, dtype: DType | None = None
) -> Tensor:
"""Resize the image to `scaled_size` and convert it to a tensor.
Args:
image: The image.
scaled_size: The target size (h, w).
device: Tensor device.
dtype: Tensor dtype.
Returns:
A tensor of shape (1, c, h, w).
"""
h, w = scaled_size
resized = image.resize((w, h), resample=Image.Resampling.BILINEAR) # type: ignore
return tensor(
np.array(resized).astype(np.float32).transpose(2, 0, 1),
device=device,
dtype=dtype,
).unsqueeze(0)
def preprocess_image(
image: Image.Image, image_encoder_resolution: int, device: Device | None = None, dtype: DType | None = None
) -> Tensor:
"""Preprocess an image without distorting its aspect ratio.
Args:
image: The image to preprocess before calling the image encoder.
image_encoder_resolution: Image encoder resolution.
device: Tensor device (None by default).
dtype: Tensor dtype (None by default).
Returns:
The preprocessed image.
"""
scaled_size = compute_scaled_size((image.height, image.width), image_encoder_resolution)
image_tensor = image_to_scaled_tensor(image, scaled_size, device, dtype)
return pad_image_tensor(
normalize(image_tensor, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
scaled_size,
image_encoder_resolution,
)
def pad_image_tensor(image_tensor: Tensor, scaled_size: tuple[int, int], image_encoder_resolution: int) -> Tensor:
"""Pad an image with zeros to make it square.
Args:
image_tensor: The image tensor to pad.
scaled_size: The scaled size (h, w).
image_encoder_resolution: Image encoder resolution.
Returns:
The padded image.
"""
assert len(image_tensor.shape) == 4
assert image_tensor.shape[2] <= image_encoder_resolution
assert image_tensor.shape[3] <= image_encoder_resolution
h, w = scaled_size
padh = image_encoder_resolution - h
padw = image_encoder_resolution - w
return pad(image_tensor, (0, padw, 0, padh))
def postprocess_masks(low_res_masks: Tensor, original_size: tuple[int, int], image_encoder_resolution: int) -> Tensor:
"""Postprocess the masks to fit the original image size and remove zero-padding (if any).
Args:
low_res_masks: The masks to postprocess.
original_size: The original size (h, w).
image_encoder_resolution: Image encoder resolution.
Returns:
The postprocessed masks.
"""
scaled_size = compute_scaled_size(original_size, image_encoder_resolution)
masks = interpolate(low_res_masks, size=Size((image_encoder_resolution, image_encoder_resolution)), mode="bilinear")
masks = masks[..., : scaled_size[0], : scaled_size[1]] # remove padding added at `preprocess_image` time
masks = interpolate(masks, size=Size(original_size), mode="bilinear")
return masks
def normalize_coordinates(coordinates: Tensor, original_size: tuple[int, int], image_encoder_resolution: int) -> Tensor:
"""Normalize the coordinates in the [0,1] range
Args:
coordinates: The coordinates to normalize.
original_size: The original image size.
image_encoder_resolution: Image encoder resolution.
Returns:
The normalized coordinates.
"""
scaled_size = compute_scaled_size(original_size, image_encoder_resolution)
coordinates[:, :, 0] = (
(coordinates[:, :, 0] * (scaled_size[1] / original_size[1])) + 0.5
) / image_encoder_resolution
coordinates[:, :, 1] = (
(coordinates[:, :, 1] * (scaled_size[0] / original_size[0])) + 0.5
) / image_encoder_resolution
return coordinates
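Taken together, the new helpers compose as follows. This is a standalone sketch using a synthetic image; the concrete sizes are only illustrative.
import torch
from PIL import Image

from refiners.foundationals.segment_anything.utils import (
    compute_scaled_size,
    normalize_coordinates,
    postprocess_masks,
    preprocess_image,
)

image_encoder_resolution = 1024
image = Image.new("RGB", (1536, 768), color="white")  # (w, h) = (1536, 768)
original_size = (image.height, image.width)  # (h, w)

# A 1536x768 image is scaled to 1024x512 to fit the 1024x1024 encoder square.
assert compute_scaled_size(original_size, image_encoder_resolution) == (512, 1024)

# Preprocessing resizes, normalizes and zero-pads to a square tensor.
assert preprocess_image(image, image_encoder_resolution).shape == (1, 3, 1024, 1024)

# Point prompts (x, y) are rescaled into [0, 1] relative to the encoder square.
points = torch.tensor([[[768.0, 384.0]]])  # center of the image
normalized = normalize_coordinates(points, original_size, image_encoder_resolution)
assert normalized.min() >= 0
assert normalized.max() <= 1

# Low-resolution mask logits are mapped back to the original image size.
low_res_masks = torch.zeros(1, 1, 256, 256)
assert postprocess_masks(low_res_masks, original_size, image_encoder_resolution).shape == (1, 1, 768, 1536)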


@@ -0,0 +1,47 @@
import pytest
import torch
from PIL import Image
from refiners.foundationals.segment_anything.utils import (
compute_scaled_size,
image_to_scaled_tensor,
pad_image_tensor,
preprocess_image,
)
@pytest.fixture
def image_encoder_resolution() -> int:
return 1024
def test_compute_scaled_size(image_encoder_resolution: int) -> None:
w, h = (1536, 768)
scaled_size = compute_scaled_size((h, w), image_encoder_resolution)
assert scaled_size == (512, 1024)
def test_image_to_scaled_tensor() -> None:
image = Image.new("RGB", (1536, 768))
tensor = image_to_scaled_tensor(image, (512, 1024))
assert tensor.shape == (1, 3, 512, 1024)
def test_preprocess_image(image_encoder_resolution: int) -> None:
image = Image.new("RGB", (1536, 768))
preprocessed = preprocess_image(image, image_encoder_resolution)
assert preprocessed.shape == (1, 3, 1024, 1024)
def test_pad_image_tensor(image_encoder_resolution: int) -> None:
w, h = (1536, 768)
image = Image.new("RGB", (w, h), color="white")
scaled_size = compute_scaled_size((h, w), image_encoder_resolution)
scaled_image_tensor = image_to_scaled_tensor(image, scaled_size)
padded_image_tensor = pad_image_tensor(scaled_image_tensor, scaled_size, image_encoder_resolution)
assert padded_image_tensor.shape == (1, 3, 1024, 1024)
assert torch.all(padded_image_tensor[:, :, 512:, :] == 0)
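The tests above exercise the preprocessing path. A hypothetical sketch of how the remaining helpers could be covered in the same style (not part of this commit; local imports keep the sketch self-contained):
def test_normalize_coordinates(image_encoder_resolution: int) -> None:
    # Local import to keep this sketch self-contained.
    from refiners.foundationals.segment_anything.utils import normalize_coordinates

    # The center of a 1536x768 image maps inside the scaled 1024x512 region
    # of the encoder square, hence into the [0, 1] range.
    coordinates = torch.tensor([[[768.0, 384.0]]])  # (x, y)
    normalized = normalize_coordinates(coordinates, (768, 1536), image_encoder_resolution)
    assert normalized.shape == (1, 1, 2)
    assert normalized.min() >= 0
    assert normalized.max() <= 1

def test_postprocess_masks(image_encoder_resolution: int) -> None:
    # Local import to keep this sketch self-contained.
    from refiners.foundationals.segment_anything.utils import postprocess_masks

    low_res_masks = torch.zeros(1, 1, 256, 256)
    masks = postprocess_masks(low_res_masks, (768, 1536), image_encoder_resolution)
    assert masks.shape == (1, 1, 768, 1536)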