From bf7852b88eb40bc2db9aec268b799335436fe38f Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Mon, 15 Apr 2024 15:38:18 +0000 Subject: [PATCH] SAM: image_to_scaled_tensor gray images --- src/refiners/fluxion/utils.py | 2 +- .../foundationals/segment_anything/utils.py | 15 ++++++--------- .../foundationals/segment_anything/test_utils.py | 9 +++++++-- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index 78ad45f..31bc41f 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -141,7 +141,7 @@ def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtyp If the image is in mode `RGB` the tensor will have shape `[3, H, W]`, otherwise `[1, H, W]` for mode `L` (grayscale) or `[4, H, W]` for mode `RGBA`. - Values are clamped to the range `[0, 1]`. + Values are normalized to the range `[0, 1]`. """ image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype) diff --git a/src/refiners/foundationals/segment_anything/utils.py b/src/refiners/foundationals/segment_anything/utils.py index 026e159..7058287 100644 --- a/src/refiners/foundationals/segment_anything/utils.py +++ b/src/refiners/foundationals/segment_anything/utils.py @@ -1,8 +1,7 @@ -import numpy as np from PIL import Image -from torch import Size, Tensor, device as Device, dtype as DType, tensor +from torch import Size, Tensor, device as Device, dtype as DType -from refiners.fluxion.utils import interpolate, normalize, pad +from refiners.fluxion.utils import image_to_tensor, interpolate, normalize, pad def compute_scaled_size(size: tuple[int, int], image_encoder_resolution: int) -> tuple[int, int]: @@ -40,11 +39,8 @@ def image_to_scaled_tensor( """ h, w = scaled_size resized = image.resize((w, h), resample=Image.Resampling.BILINEAR) # type: ignore - return tensor( - np.array(resized).astype(np.float32).transpose(2, 0, 1), - device=device, - dtype=dtype, - ).unsqueeze(0) + + return image_to_tensor(resized, device=device, dtype=dtype) * 255.0 def preprocess_image( @@ -61,9 +57,10 @@ def preprocess_image( Returns: The preprocessed image. """ + scaled_size = compute_scaled_size((image.height, image.width), image_encoder_resolution) - image_tensor = image_to_scaled_tensor(image, scaled_size, device, dtype) + image_tensor = image_to_scaled_tensor(image, scaled_size, device=device, dtype=dtype) return pad_image_tensor( normalize(image_tensor, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]), diff --git a/tests/foundationals/segment_anything/test_utils.py b/tests/foundationals/segment_anything/test_utils.py index 6966cf9..b907804 100644 --- a/tests/foundationals/segment_anything/test_utils.py +++ b/tests/foundationals/segment_anything/test_utils.py @@ -22,13 +22,18 @@ def test_compute_scaled_size(image_encoder_resolution: int) -> None: assert scaled_size == (512, 1024) -def test_image_to_scaled_tensor() -> None: +def test_rgb_image_to_scaled_tensor() -> None: image = Image.new("RGB", (1536, 768)) tensor = image_to_scaled_tensor(image, (512, 1024)) - assert tensor.shape == (1, 3, 512, 1024) +def test_grayscale_image_to_scaled_tensor() -> None: + image = Image.new("L", (1536, 768)) + tensor = image_to_scaled_tensor(image, (512, 1024)) + assert tensor.shape == (1, 1, 512, 1024) + + def test_preprocess_image(image_encoder_resolution: int) -> None: image = Image.new("RGB", (1536, 768)) preprocessed = preprocess_image(image, image_encoder_resolution)