improve image_to_tensor and tensor_to_image utils

This commit is contained in:
Benjamin Trom 2023-10-17 18:00:14 +02:00
parent 585c7ad55a
commit 6ddd901767
2 changed files with 56 additions and 5 deletions

View file

@ -102,13 +102,49 @@ def gaussian_blur(
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
    """
    Convert a PIL Image to a batched Tensor.

    The returned tensor has shape `[1, C, H, W]`, where `C` is 1 for mode `L`
    (grayscale), 3 for mode `RGB` and 4 for mode `RGBA`.

    Pixel values are divided by 255, which maps 8-bit channels to the range `[0, 1]`.
    No clamping is performed.

    Args:
        image: The image to convert; its mode must be `L`, `RGB` or `RGBA`.
        device: Forwarded to `torch.tensor` as the target device.
        dtype: Forwarded to `torch.tensor` as the target dtype.

    Returns:
        A tensor of shape `[1, C, H, W]`.

    Raises:
        ValueError: If the image mode is not one of `L`, `RGB` or `RGBA`.
    """
    mode = image.mode
    if mode not in ("L", "RGB", "RGBA"):
        # Reject unsupported modes up front, before any array conversion or device allocation.
        raise ValueError(f"Unsupported image mode: {mode}")

    image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype)

    if mode == "L":
        # Grayscale arrays are [H, W]: add the missing channel axis.
        image_tensor = image_tensor.unsqueeze(0)
    else:
        # Color arrays are channels-last [H, W, C]: move channels first.
        image_tensor = image_tensor.permute(2, 0, 1)

    # Add the leading batch dimension.
    return image_tensor.unsqueeze(0)
def tensor_to_image(tensor: Tensor) -> Image.Image:
    """
    Convert a Tensor to a PIL Image.

    The tensor must have shape `[1, channels, height, width]` where the number of
    channels is either 1 (grayscale), 3 (RGB) or 4 (RGBA).

    Expected values are in the range `[0, 1]` and are clamped to this range before
    being scaled by 255 and cast to `uint8`.

    Args:
        tensor: The tensor to convert.

    Returns:
        The PIL image built from the tensor data.

    Raises:
        ValueError: If the tensor is not 4-dimensional with a batch size of 1, or if
            the number of channels is not 1, 3 or 4.
    """
    # Explicit raise rather than `assert`, so the check survives `python -O`.
    if tensor.ndim != 4 or tensor.shape[0] != 1:
        raise ValueError(f"Unsupported tensor shape: {tensor.shape}")

    num_channels = tensor.shape[1]
    # `detach` lets tensors that require grad be converted to numpy below.
    tensor = tensor.detach().clamp(0, 1).squeeze(0)

    if num_channels == 1:
        # Drop the channel axis to get a 2D [H, W] array.
        tensor = tensor.squeeze(0)
    elif num_channels in (3, 4):
        # Move channels last: [C, H, W] -> [H, W, C].
        tensor = tensor.permute(1, 2, 0)
    else:
        raise ValueError(f"Unsupported number of channels: {num_channels}")

    return Image.fromarray((tensor.cpu().numpy() * 255).astype("uint8"))  # type: ignore[reportUnknownType]
def safe_open( def safe_open(

View file

@ -3,10 +3,11 @@ from warnings import warn
from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore
from torch import device as Device, dtype as DType from torch import device as Device, dtype as DType
from PIL import Image
import pytest import pytest
import torch import torch
from refiners.fluxion.utils import gaussian_blur, manual_seed from refiners.fluxion.utils import gaussian_blur, image_to_tensor, manual_seed, tensor_to_image
@dataclass @dataclass
@ -47,3 +48,17 @@ def test_gaussian_blur(test_device: Device, blur_input: BlurInput) -> None:
our_blur = gaussian_blur(tensor, blur_input.kernel_size, blur_input.sigma) our_blur = gaussian_blur(tensor, blur_input.kernel_size, blur_input.sigma)
assert torch.equal(our_blur, ref_blur) assert torch.equal(our_blur, ref_blur)
def test_image_to_tensor() -> None:
    """Check the `[1, C, H, W]` shape for each supported mode and the rejection of other modes."""
    # Non-square on purpose: PIL sizes are (width, height), so a swapped H/W would be caught.
    width, height = 64, 32
    image = Image.new("RGB", (width, height))

    assert image_to_tensor(image).shape == (1, 3, height, width)
    assert image_to_tensor(image.convert("L")).shape == (1, 1, height, width)
    assert image_to_tensor(image.convert("RGBA")).shape == (1, 4, height, width)

    # Modes other than L / RGB / RGBA must be rejected explicitly.
    with pytest.raises(ValueError):
        image_to_tensor(image.convert("CMYK"))
def test_tensor_to_image() -> None:
    """Check the image mode and size for each supported channel count, and the rejection of others."""
    # Non-square on purpose: PIL sizes are (width, height), so a swapped H/W would be caught.
    height, width = 32, 64

    for channels, expected_mode in ((3, "RGB"), (1, "L"), (4, "RGBA")):
        image = tensor_to_image(torch.zeros(1, channels, height, width))
        assert image.mode == expected_mode
        assert image.size == (width, height)

    # Channel counts other than 1, 3 or 4 must be rejected explicitly.
    with pytest.raises(ValueError):
        tensor_to_image(torch.zeros(1, 2, height, width))