improve image_to_tensor and tensor_to_image utils

This commit is contained in:
Benjamin Trom 2023-10-17 18:00:14 +02:00
parent 585c7ad55a
commit 6ddd901767
2 changed files with 56 additions and 5 deletions

View file

@ -102,13 +102,49 @@ def gaussian_blur(
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
return torch.tensor(array(image).astype(float32).transpose(2, 0, 1) / 255.0, device=device, dtype=dtype).unsqueeze(
0
)
"""
Convert a PIL Image to a Tensor.
If the image is in mode `RGB` the tensor will have shape `[3, H, W]`, otherwise
`[1, H, W]` for mode `L` (grayscale) or `[4, H, W]` for mode `RGBA`.
Values are clamped to the range `[0, 1]`.
"""
image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype)
match image.mode:
case "L":
image_tensor = image_tensor.unsqueeze(0)
case "RGBA" | "RGB":
image_tensor = image_tensor.permute(2, 0, 1)
case _:
raise ValueError(f"Unsupported image mode: {image.mode}")
return image_tensor.unsqueeze(0)
def tensor_to_image(tensor: Tensor) -> Image.Image:
return Image.fromarray((tensor.clamp(0, 1).squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype("uint8")) # type: ignore
"""
Convert a Tensor to a PIL Image.
The tensor must have shape `[1, channels, height, width]` where the number of
channels is either 1 (grayscale) or 3 (RGB) or 4 (RGBA).
Expected values are in the range `[0, 1]` and are clamped to this range.
"""
assert tensor.ndim == 4 and tensor.shape[0] == 1, f"Unsupported tensor shape: {tensor.shape}"
num_channels = tensor.shape[1]
tensor = tensor.clamp(0, 1).squeeze(0)
match num_channels:
case 1:
tensor = tensor.squeeze(0)
case 3 | 4:
tensor = tensor.permute(1, 2, 0)
case _:
raise ValueError(f"Unsupported number of channels: {num_channels}")
return Image.fromarray((tensor.cpu().numpy() * 255).astype("uint8")) # type: ignore[reportUnknownType]
def safe_open(

View file

@ -3,10 +3,11 @@ from warnings import warn
from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore
from torch import device as Device, dtype as DType
from PIL import Image
import pytest
import torch
from refiners.fluxion.utils import gaussian_blur, manual_seed
from refiners.fluxion.utils import gaussian_blur, image_to_tensor, manual_seed, tensor_to_image
@dataclass
@ -47,3 +48,17 @@ def test_gaussian_blur(test_device: Device, blur_input: BlurInput) -> None:
our_blur = gaussian_blur(tensor, blur_input.kernel_size, blur_input.sigma)
assert torch.equal(our_blur, ref_blur)
def test_image_to_tensor() -> None:
image = Image.new("RGB", (512, 512))
assert image_to_tensor(image).shape == (1, 3, 512, 512)
assert image_to_tensor(image.convert("L")).shape == (1, 1, 512, 512)
assert image_to_tensor(image.convert("RGBA")).shape == (1, 4, 512, 512)
def test_tensor_to_image() -> None:
assert tensor_to_image(torch.zeros(1, 3, 512, 512)).mode == "RGB"
assert tensor_to_image(torch.zeros(1, 1, 512, 512)).mode == "L"
assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA"