mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
SAM: image_to_scaled_tensor gray images
This commit is contained in:
parent
f48712ee29
commit
bf7852b88e
|
@ -141,7 +141,7 @@ def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtyp
|
||||||
If the image is in mode `RGB` the tensor will have shape `[3, H, W]`,
|
If the image is in mode `RGB` the tensor will have shape `[3, H, W]`,
|
||||||
otherwise `[1, H, W]` for mode `L` (grayscale) or `[4, H, W]` for mode `RGBA`.
|
otherwise `[1, H, W]` for mode `L` (grayscale) or `[4, H, W]` for mode `RGBA`.
|
||||||
|
|
||||||
Values are clamped to the range `[0, 1]`.
|
Values are normalized to the range `[0, 1]`.
|
||||||
"""
|
"""
|
||||||
image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype)
|
image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch import Size, Tensor, device as Device, dtype as DType, tensor
|
from torch import Size, Tensor, device as Device, dtype as DType
|
||||||
|
|
||||||
from refiners.fluxion.utils import interpolate, normalize, pad
|
from refiners.fluxion.utils import image_to_tensor, interpolate, normalize, pad
|
||||||
|
|
||||||
|
|
||||||
def compute_scaled_size(size: tuple[int, int], image_encoder_resolution: int) -> tuple[int, int]:
|
def compute_scaled_size(size: tuple[int, int], image_encoder_resolution: int) -> tuple[int, int]:
|
||||||
|
@ -40,11 +39,8 @@ def image_to_scaled_tensor(
|
||||||
"""
|
"""
|
||||||
h, w = scaled_size
|
h, w = scaled_size
|
||||||
resized = image.resize((w, h), resample=Image.Resampling.BILINEAR) # type: ignore
|
resized = image.resize((w, h), resample=Image.Resampling.BILINEAR) # type: ignore
|
||||||
return tensor(
|
|
||||||
np.array(resized).astype(np.float32).transpose(2, 0, 1),
|
return image_to_tensor(resized, device=device, dtype=dtype) * 255.0
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
).unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_image(
|
def preprocess_image(
|
||||||
|
@ -61,9 +57,10 @@ def preprocess_image(
|
||||||
Returns:
|
Returns:
|
||||||
The preprocessed image.
|
The preprocessed image.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
scaled_size = compute_scaled_size((image.height, image.width), image_encoder_resolution)
|
scaled_size = compute_scaled_size((image.height, image.width), image_encoder_resolution)
|
||||||
|
|
||||||
image_tensor = image_to_scaled_tensor(image, scaled_size, device, dtype)
|
image_tensor = image_to_scaled_tensor(image, scaled_size, device=device, dtype=dtype)
|
||||||
|
|
||||||
return pad_image_tensor(
|
return pad_image_tensor(
|
||||||
normalize(image_tensor, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
|
normalize(image_tensor, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
|
||||||
|
|
|
@ -22,13 +22,18 @@ def test_compute_scaled_size(image_encoder_resolution: int) -> None:
|
||||||
assert scaled_size == (512, 1024)
|
assert scaled_size == (512, 1024)
|
||||||
|
|
||||||
|
|
||||||
def test_image_to_scaled_tensor() -> None:
|
def test_rgb_image_to_scaled_tensor() -> None:
|
||||||
image = Image.new("RGB", (1536, 768))
|
image = Image.new("RGB", (1536, 768))
|
||||||
tensor = image_to_scaled_tensor(image, (512, 1024))
|
tensor = image_to_scaled_tensor(image, (512, 1024))
|
||||||
|
|
||||||
assert tensor.shape == (1, 3, 512, 1024)
|
assert tensor.shape == (1, 3, 512, 1024)
|
||||||
|
|
||||||
|
|
||||||
|
def test_grayscale_image_to_scaled_tensor() -> None:
|
||||||
|
image = Image.new("L", (1536, 768))
|
||||||
|
tensor = image_to_scaled_tensor(image, (512, 1024))
|
||||||
|
assert tensor.shape == (1, 1, 512, 1024)
|
||||||
|
|
||||||
|
|
||||||
def test_preprocess_image(image_encoder_resolution: int) -> None:
|
def test_preprocess_image(image_encoder_resolution: int) -> None:
|
||||||
image = Image.new("RGB", (1536, 768))
|
image = Image.new("RGB", (1536, 768))
|
||||||
preprocessed = preprocess_image(image, image_encoder_resolution)
|
preprocessed = preprocess_image(image, image_encoder_resolution)
|
||||||
|
|
Loading…
Reference in a new issue