mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
improve image_to_tensor and tensor_to_image utils
This commit is contained in:
parent
585c7ad55a
commit
6ddd901767
|
@ -102,13 +102,49 @@ def gaussian_blur(
|
|||
|
||||
|
||||
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
|
||||
return torch.tensor(array(image).astype(float32).transpose(2, 0, 1) / 255.0, device=device, dtype=dtype).unsqueeze(
|
||||
0
|
||||
)
|
||||
"""
|
||||
Convert a PIL Image to a Tensor.
|
||||
|
||||
If the image is in mode `RGB` the tensor will have shape `[3, H, W]`, otherwise
|
||||
`[1, H, W]` for mode `L` (grayscale) or `[4, H, W]` for mode `RGBA`.
|
||||
|
||||
Values are clamped to the range `[0, 1]`.
|
||||
"""
|
||||
image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype)
|
||||
|
||||
match image.mode:
|
||||
case "L":
|
||||
image_tensor = image_tensor.unsqueeze(0)
|
||||
case "RGBA" | "RGB":
|
||||
image_tensor = image_tensor.permute(2, 0, 1)
|
||||
case _:
|
||||
raise ValueError(f"Unsupported image mode: {image.mode}")
|
||||
|
||||
return image_tensor.unsqueeze(0)
|
||||
|
||||
|
||||
def tensor_to_image(tensor: Tensor) -> Image.Image:
|
||||
return Image.fromarray((tensor.clamp(0, 1).squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype("uint8")) # type: ignore
|
||||
"""
|
||||
Convert a Tensor to a PIL Image.
|
||||
|
||||
The tensor must have shape `[1, channels, height, width]` where the number of
|
||||
channels is either 1 (grayscale) or 3 (RGB) or 4 (RGBA).
|
||||
|
||||
Expected values are in the range `[0, 1]` and are clamped to this range.
|
||||
"""
|
||||
assert tensor.ndim == 4 and tensor.shape[0] == 1, f"Unsupported tensor shape: {tensor.shape}"
|
||||
num_channels = tensor.shape[1]
|
||||
tensor = tensor.clamp(0, 1).squeeze(0)
|
||||
|
||||
match num_channels:
|
||||
case 1:
|
||||
tensor = tensor.squeeze(0)
|
||||
case 3 | 4:
|
||||
tensor = tensor.permute(1, 2, 0)
|
||||
case _:
|
||||
raise ValueError(f"Unsupported number of channels: {num_channels}")
|
||||
|
||||
return Image.fromarray((tensor.cpu().numpy() * 255).astype("uint8")) # type: ignore[reportUnknownType]
|
||||
|
||||
|
||||
def safe_open(
|
||||
|
|
|
@ -3,10 +3,11 @@ from warnings import warn
|
|||
|
||||
from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore
|
||||
from torch import device as Device, dtype as DType
|
||||
from PIL import Image
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from refiners.fluxion.utils import gaussian_blur, manual_seed
|
||||
from refiners.fluxion.utils import gaussian_blur, image_to_tensor, manual_seed, tensor_to_image
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -47,3 +48,17 @@ def test_gaussian_blur(test_device: Device, blur_input: BlurInput) -> None:
|
|||
our_blur = gaussian_blur(tensor, blur_input.kernel_size, blur_input.sigma)
|
||||
|
||||
assert torch.equal(our_blur, ref_blur)
|
||||
|
||||
|
||||
def test_image_to_tensor() -> None:
|
||||
image = Image.new("RGB", (512, 512))
|
||||
|
||||
assert image_to_tensor(image).shape == (1, 3, 512, 512)
|
||||
assert image_to_tensor(image.convert("L")).shape == (1, 1, 512, 512)
|
||||
assert image_to_tensor(image.convert("RGBA")).shape == (1, 4, 512, 512)
|
||||
|
||||
|
||||
def test_tensor_to_image() -> None:
|
||||
assert tensor_to_image(torch.zeros(1, 3, 512, 512)).mode == "RGB"
|
||||
assert tensor_to_image(torch.zeros(1, 1, 512, 512)).mode == "L"
|
||||
assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA"
|
||||
|
|
Loading…
Reference in a new issue