SAM: image_to_scaled_tensor gray images

This commit is contained in:
Pierre Colle 2024-04-15 15:38:18 +00:00 committed by Colle
parent f48712ee29
commit bf7852b88e
3 changed files with 14 additions and 12 deletions

View file

@ -141,7 +141,7 @@ def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtyp
If the image is in mode `RGB` the tensor will have shape `[3, H, W]`,
otherwise `[1, H, W]` for mode `L` (grayscale) or `[4, H, W]` for mode `RGBA`.
Values are clamped to the range `[0, 1]`.
Values are normalized to the range `[0, 1]`.
"""
image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype)

View file

@ -1,8 +1,7 @@
import numpy as np
from PIL import Image
from torch import Size, Tensor, device as Device, dtype as DType, tensor
from torch import Size, Tensor, device as Device, dtype as DType
from refiners.fluxion.utils import interpolate, normalize, pad
from refiners.fluxion.utils import image_to_tensor, interpolate, normalize, pad
def compute_scaled_size(size: tuple[int, int], image_encoder_resolution: int) -> tuple[int, int]:
@ -40,11 +39,8 @@ def image_to_scaled_tensor(
"""
h, w = scaled_size
resized = image.resize((w, h), resample=Image.Resampling.BILINEAR) # type: ignore
return tensor(
np.array(resized).astype(np.float32).transpose(2, 0, 1),
device=device,
dtype=dtype,
).unsqueeze(0)
return image_to_tensor(resized, device=device, dtype=dtype) * 255.0
def preprocess_image(
@ -61,9 +57,10 @@ def preprocess_image(
Returns:
The preprocessed image.
"""
scaled_size = compute_scaled_size((image.height, image.width), image_encoder_resolution)
image_tensor = image_to_scaled_tensor(image, scaled_size, device, dtype)
image_tensor = image_to_scaled_tensor(image, scaled_size, device=device, dtype=dtype)
return pad_image_tensor(
normalize(image_tensor, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),

View file

@ -22,13 +22,18 @@ def test_compute_scaled_size(image_encoder_resolution: int) -> None:
assert scaled_size == (512, 1024)
def test_image_to_scaled_tensor() -> None:
def test_rgb_image_to_scaled_tensor() -> None:
image = Image.new("RGB", (1536, 768))
tensor = image_to_scaled_tensor(image, (512, 1024))
assert tensor.shape == (1, 3, 512, 1024)
def test_grayscale_image_to_scaled_tensor() -> None:
image = Image.new("L", (1536, 768))
tensor = image_to_scaled_tensor(image, (512, 1024))
assert tensor.shape == (1, 1, 512, 1024)
def test_preprocess_image(image_encoder_resolution: int) -> None:
image = Image.new("RGB", (1536, 768))
preprocessed = preprocess_image(image, image_encoder_resolution)