mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-21 21:58:47 +00:00)

SAM: expose sizing helpers

parent 06ff2f0a5f
commit 0ac290f67d
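In short, this commit moves the sizing logic that previously lived as private methods on `SegmentAnything` into standalone helpers under `refiners.foundationals.segment_anything.utils`. A minimal usage sketch of the new module-level API, based on the signatures added in this diff (the input image and the 1024 encoder resolution are illustrative assumptions):

    from PIL import Image

    from refiners.foundationals.segment_anything.utils import compute_scaled_size, preprocess_image

    image = Image.open("photo.jpg").convert("RGB")  # hypothetical input image
    resolution = 1024  # assumed image encoder resolution (SAM ViT-H)

    # Aspect-preserving size that fits inside the encoder's square input.
    scaled_size = compute_scaled_size((image.height, image.width), resolution)

    # Resized, normalized and zero-padded tensor of shape (1, 3, resolution, resolution),
    # ready to be fed to the image encoder.
    x = preprocess_image(image, resolution)

    # Point prompts of shape (batch, num_points, 2) in pixel space can be mapped to the
    # [0, 1] range with normalize_coordinates(points, (image.height, image.width), resolution),
    # and low-resolution masks resized back with postprocess_masks(low_res_masks,
    # (image.height, image.width), resolution).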
@@ -1,17 +1,21 @@
 from dataclasses import dataclass
 from typing import Sequence
 
-import numpy as np
 import torch
 from jaxtyping import Float
 from PIL import Image
 from torch import Tensor, device as Device, dtype as DType
 
 import refiners.fluxion.layers as fl
-from refiners.fluxion.utils import interpolate, no_grad, normalize, pad
+from refiners.fluxion.utils import no_grad
 from refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH
 from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
 from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
+from refiners.foundationals.segment_anything.utils import (
+    normalize_coordinates,
+    postprocess_masks,
+    preprocess_image,
+)
 
 
 @dataclass
@@ -85,9 +89,8 @@ class SegmentAnything(fl.Chain):
             The computed image embedding.
         """
         original_size = (image.height, image.width)
-        target_size = self.compute_target_size(original_size)
         return ImageEmbedding(
-            features=self.image_encoder(self.preprocess_image(image=image, target_size=target_size)),
+            features=self.image_encoder(self.preprocess_image(image)),
             original_image_size=original_size,
         )
 
@@ -118,12 +121,10 @@ class SegmentAnything(fl.Chain):
         """
         if isinstance(input, ImageEmbedding):
             original_size = input.original_image_size
-            target_size = self.compute_target_size(original_size)
             image_embedding = input.features
         else:
             original_size = (input.height, input.width)
-            target_size = self.compute_target_size(original_size)
-            image_embedding = self.image_encoder(self.preprocess_image(image=input, target_size=target_size))
+            image_embedding = self.image_encoder(self.preprocess_image(input))
 
         coordinates, type_mask = self.point_encoder.points_to_tensor(
             foreground_points=foreground_points,
@@ -139,9 +140,7 @@ class SegmentAnything(fl.Chain):
             image_embedding_size=self.image_encoder.image_embedding_size
         )
 
-        point_embedding = self.point_encoder(
-            self.normalize(coordinates, target_size=target_size, original_size=original_size)
-        )
+        point_embedding = self.point_encoder(self.normalize(coordinates, original_size=original_size))
         dense_positional_embedding = self.point_encoder.get_dense_positional_embedding(
             image_embedding_size=self.image_encoder.image_embedding_size
         )
@@ -153,9 +152,7 @@ class SegmentAnything(fl.Chain):
 
         low_res_masks, iou_predictions = self.mask_decoder()
 
-        high_res_masks = self.postprocess_masks(
-            masks=low_res_masks, target_size=target_size, original_size=original_size
-        )
+        high_res_masks = self.postprocess_masks(low_res_masks, original_size)
 
         if binarize:
             high_res_masks = high_res_masks > self.mask_threshold
@@ -163,82 +160,43 @@ class SegmentAnything(fl.Chain):
         return high_res_masks, iou_predictions, low_res_masks
 
     @property
-    def image_size(self) -> int:
-        """The image size."""
+    def image_encoder_resolution(self) -> int:
+        """The resolution of the image encoder."""
         w, h = self.image_encoder.image_size
        assert w == h
        return w
 
-    def compute_target_size(self, size: tuple[int, int]) -> tuple[int, int]:
-        """Compute the target size as expected by the image encoder.
-
-        Args:
-            size: The size of the input image.
-
-        Returns:
-            The target height.
-            The target width.
-        """
-        oldh, oldw = size
-        scale = self.image_size * 1.0 / max(oldh, oldw)
-        newh, neww = oldh * scale, oldw * scale
-        neww = int(neww + 0.5)
-        newh = int(newh + 0.5)
-        return (newh, neww)
-
-    def preprocess_image(self, image: Image.Image, target_size: tuple[int, int]) -> Tensor:
-        """Preprocess an image without distorting its aspect ratio.
-
-        Args:
-            image: The image to preprocess.
-            target_size: The target size.
-
-        Returns:
-            The preprocessed image.
-        """
-        h, w = target_size
-        padh = self.image_size - h
-        padw = self.image_size - w
-        resized = image.resize((w, h), resample=Image.Resampling.BILINEAR)  # type: ignore
-        image_tensor = torch.tensor(
-            np.array(resized).astype(np.float32).transpose(2, 0, 1),
-            device=self.device,
-            dtype=self.dtype,
-        ).unsqueeze(0)
-        return pad(
-            normalize(image_tensor, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]), (0, padw, 0, padh)
-        )
-
-    def normalize(self, coordinates: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
-        """Normalize the coordinates.
-
-        Args:
-            coordinates: The coordinates to normalize.
-            target_size: The target size.
-            original_size: The original size.
-
-        Returns:
-            The normalized coordinates.
-        """
-        coordinates[:, :, 0] = ((coordinates[:, :, 0] * (target_size[1] / original_size[1])) + 0.5) / self.image_size
-        coordinates[:, :, 1] = ((coordinates[:, :, 1] * (target_size[0] / original_size[0])) + 0.5) / self.image_size
-        return coordinates
-
-    def postprocess_masks(self, masks: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
-        """Postprocess the masks.
-
-        Args:
-            masks: The masks to postprocess.
-            target_size: The target size.
-            original_size: The original size.
-
-        Returns:
-            The postprocessed masks.
-        """
-        masks = interpolate(masks, size=torch.Size((self.image_size, self.image_size)), mode="bilinear")
-        masks = masks[..., : target_size[0], : target_size[1]]  # remove padding added at `preprocess_image` time
-        masks = interpolate(masks, size=torch.Size(original_size), mode="bilinear")
-        return masks
+    def preprocess_image(self, image: Image.Image) -> Tensor:
+        """
+        See [`preprocess_image`][refiners.foundationals.segment_anything.utils.preprocess_image]
+
+        Args:
+            image: The image to preprocess.
+
+        Returns:
+            The preprocessed tensor.
+        """
+        return preprocess_image(image, self.image_encoder_resolution, self.device, self.dtype)
+
+    def normalize(self, coordinates: Tensor, original_size: tuple[int, int]) -> Tensor:
+        """
+        See [`normalize_coordinates`][refiners.foundationals.segment_anything.utils.normalize_coordinates]
+
+        Args:
+            coordinates: a tensor of coordinates.
+            original_size: (h, w) the original size of the image.
+
+        Returns:
+            The [0,1] normalized coordinates tensor.
+        """
+        return normalize_coordinates(coordinates, original_size, self.image_encoder_resolution)
+
+    def postprocess_masks(self, low_res_masks: Tensor, original_size: tuple[int, int]) -> Tensor:
+        """
+        See [`postprocess_masks`][refiners.foundationals.segment_anything.utils.postprocess_masks]
+
+        Args:
+            low_res_masks: a mask tensor of size (N, 1, 256, 256)
+            original_size: (h, w) the original size of the image.
+
+        Returns:
+            The mask of shape (N, 1, H, W)
+        """
+        return postprocess_masks(low_res_masks, original_size, self.image_encoder_resolution)
 
 
 class SegmentAnythingH(SegmentAnything):
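After this change, callers of `SegmentAnything` no longer thread a `target_size` through the sizing methods; each method derives the scaled size from the original image size and the encoder resolution. A rough sketch of the simplified calls, assuming `sam` is an instantiated model and `image` a PIL image (both illustrative):

    # Assumed: `sam` is a SegmentAnythingH instance, `image` a PIL.Image.Image.
    original_size = (image.height, image.width)

    x = sam.preprocess_image(image)  # no target_size argument anymore; shape (1, 3, R, R)
    # where R == sam.image_encoder_resolution (the renamed `image_size` property).

    # Point coordinates and low-resolution masks are likewise handled from the original size only:
    # coordinates = sam.normalize(coordinates, original_size=original_size)
    # high_res_masks = sam.postprocess_masks(low_res_masks, original_size)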
src/refiners/foundationals/segment_anything/utils.py (new file, 132 lines)

import numpy as np
from PIL import Image
from torch import Size, Tensor, device as Device, dtype as DType, tensor

from refiners.fluxion.utils import interpolate, normalize, pad


def compute_scaled_size(size: tuple[int, int], image_encoder_resolution: int) -> tuple[int, int]:
    """Compute the scaled size as expected by the image encoder.

    This computed size keeps the aspect ratio of the input image and scales it to fit inside
    the (image_encoder_resolution, image_encoder_resolution) square of the image encoder.

    Args:
        size: The size (h, w) of the input image.
        image_encoder_resolution: Image encoder resolution.

    Returns:
        The target height.
        The target width.
    """
    oldh, oldw = size
    scale = image_encoder_resolution * 1.0 / max(oldh, oldw)
    newh, neww = oldh * scale, oldw * scale
    neww = int(neww + 0.5)
    newh = int(newh + 0.5)
    return (newh, neww)


def image_to_scaled_tensor(
    image: Image.Image, scaled_size: tuple[int, int], device: Device | None = None, dtype: DType | None = None
) -> Tensor:
    """Resize the image to `scaled_size` and convert it to a tensor.

    Args:
        image: The image.
        scaled_size: The target size (h, w).
        device: Tensor device.
        dtype: Tensor dtype.

    Returns:
        a Tensor of shape (1, c, h, w)
    """
    h, w = scaled_size
    resized = image.resize((w, h), resample=Image.Resampling.BILINEAR)  # type: ignore
    return tensor(
        np.array(resized).astype(np.float32).transpose(2, 0, 1),
        device=device,
        dtype=dtype,
    ).unsqueeze(0)


def preprocess_image(
    image: Image.Image, image_encoder_resolution: int, device: Device | None = None, dtype: DType | None = None
) -> Tensor:
    """Preprocess an image without distorting its aspect ratio.

    Args:
        image: The image to preprocess before calling the image encoder.
        image_encoder_resolution: Image encoder resolution.
        device: Tensor device (None by default).
        dtype: Tensor dtype (None by default).

    Returns:
        The preprocessed image.
    """
    scaled_size = compute_scaled_size((image.height, image.width), image_encoder_resolution)

    image_tensor = image_to_scaled_tensor(image, scaled_size, device, dtype)

    return pad_image_tensor(
        normalize(image_tensor, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
        scaled_size,
        image_encoder_resolution,
    )


def pad_image_tensor(image_tensor: Tensor, scaled_size: tuple[int, int], image_encoder_resolution: int) -> Tensor:
    """Pad an image with zeros to make it square.

    Args:
        image_tensor: The image tensor to pad.
        scaled_size: The scaled size (h, w).
        image_encoder_resolution: Image encoder resolution.

    Returns:
        The padded image.
    """
    assert len(image_tensor.shape) == 4
    assert image_tensor.shape[2] <= image_encoder_resolution
    assert image_tensor.shape[3] <= image_encoder_resolution

    h, w = scaled_size
    padh = image_encoder_resolution - h
    padw = image_encoder_resolution - w
    return pad(image_tensor, (0, padw, 0, padh))


def postprocess_masks(low_res_masks: Tensor, original_size: tuple[int, int], image_encoder_resolution: int) -> Tensor:
    """Postprocess the masks to fit the original image size and remove zero-padding (if any).

    Args:
        low_res_masks: The masks to postprocess.
        original_size: The original size (h, w).
        image_encoder_resolution: Image encoder resolution.

    Returns:
        The postprocessed masks.
    """
    scaled_size = compute_scaled_size(original_size, image_encoder_resolution)
    masks = interpolate(low_res_masks, size=Size((image_encoder_resolution, image_encoder_resolution)), mode="bilinear")
    masks = masks[..., : scaled_size[0], : scaled_size[1]]  # remove padding added at `preprocess_image` time
    masks = interpolate(masks, size=Size(original_size), mode="bilinear")
    return masks


def normalize_coordinates(coordinates: Tensor, original_size: tuple[int, int], image_encoder_resolution: int) -> Tensor:
    """Normalize the coordinates to the [0, 1] range.

    Args:
        coordinates: The coordinates to normalize.
        original_size: The original image size.
        image_encoder_resolution: Image encoder resolution.

    Returns:
        The normalized coordinates.
    """
    scaled_size = compute_scaled_size(original_size, image_encoder_resolution)
    coordinates[:, :, 0] = (
        (coordinates[:, :, 0] * (scaled_size[1] / original_size[1])) + 0.5
    ) / image_encoder_resolution
    coordinates[:, :, 1] = (
        (coordinates[:, :, 1] * (scaled_size[0] / original_size[0])) + 0.5
    ) / image_encoder_resolution
    return coordinates
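Taken together, these helpers implement an aspect-preserving resize-then-pad scheme. A short worked sketch, reusing the 1536x768 image and 1024 resolution from the tests below (the image itself is illustrative):

    from PIL import Image

    from refiners.foundationals.segment_anything.utils import (
        compute_scaled_size,
        image_to_scaled_tensor,
        pad_image_tensor,
        preprocess_image,
    )

    image = Image.new("RGB", (1536, 768))  # (w, h), illustrative

    # 1536x768 scaled to fit a 1024x1024 encoder input while keeping the aspect ratio:
    scaled_size = compute_scaled_size((image.height, image.width), 1024)  # (512, 1024)

    # Resize and tensor conversion, then zero-padding to the full square:
    scaled = image_to_scaled_tensor(image, scaled_size)    # shape (1, 3, 512, 1024)
    padded = pad_image_tensor(scaled, scaled_size, 1024)   # shape (1, 3, 1024, 1024)

    # preprocess_image performs the same steps plus the mean/std normalization:
    x = preprocess_image(image, 1024)                      # shape (1, 3, 1024, 1024)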
tests/foundationals/segment_anything/test_utils.py (new file, 47 lines)

import pytest
import torch
from PIL import Image

from refiners.foundationals.segment_anything.utils import (
    compute_scaled_size,
    image_to_scaled_tensor,
    pad_image_tensor,
    preprocess_image,
)


@pytest.fixture
def image_encoder_resolution() -> int:
    return 1024


def test_compute_scaled_size(image_encoder_resolution: int) -> None:
    w, h = (1536, 768)
    scaled_size = compute_scaled_size((h, w), image_encoder_resolution)

    assert scaled_size == (512, 1024)


def test_image_to_scaled_tensor() -> None:
    image = Image.new("RGB", (1536, 768))
    tensor = image_to_scaled_tensor(image, (512, 1024))

    assert tensor.shape == (1, 3, 512, 1024)


def test_preprocess_image(image_encoder_resolution: int) -> None:
    image = Image.new("RGB", (1536, 768))
    preprocessed = preprocess_image(image, image_encoder_resolution)

    assert preprocessed.shape == (1, 3, 1024, 1024)


def test_pad_image_tensor(image_encoder_resolution: int) -> None:
    w, h = (1536, 768)
    image = Image.new("RGB", (w, h), color="white")
    scaled_size = compute_scaled_size((h, w), image_encoder_resolution)
    scaled_image_tensor = image_to_scaled_tensor(image, scaled_size)
    padded_image_tensor = pad_image_tensor(scaled_image_tensor, scaled_size, image_encoder_resolution)

    assert padded_image_tensor.shape == (1, 3, 1024, 1024)
    assert torch.all(padded_image_tensor[:, :, 512:, :] == 0)
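The new tests exercise the scaling, tensor conversion, preprocessing and padding helpers. For comparison, a sketch of an analogous shape check for `postprocess_masks` (not part of this commit, purely illustrative of how the helper behaves for a 768x1536 image):

    import torch

    from refiners.foundationals.segment_anything.utils import postprocess_masks


    def test_postprocess_masks_shape() -> None:
        low_res_masks = torch.zeros(2, 1, 256, 256)  # (N, 1, 256, 256), as in the docstring
        masks = postprocess_masks(low_res_masks, (768, 1536), 1024)

        # Upscaled to 1024x1024, cropped to the padded region (512x1024), then resized to (768, 1536).
        assert masks.shape == (2, 1, 768, 1536)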