mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
improve image_to_tensor and tensor_to_image utils
This commit is contained in:
parent
585c7ad55a
commit
6ddd901767
|
@ -102,13 +102,49 @@ def gaussian_blur(
|
||||||
|
|
||||||
|
|
||||||
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
|
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
|
||||||
return torch.tensor(array(image).astype(float32).transpose(2, 0, 1) / 255.0, device=device, dtype=dtype).unsqueeze(
|
"""
|
||||||
0
|
Convert a PIL Image to a Tensor.
|
||||||
)
|
|
||||||
|
If the image is in mode `RGB` the tensor will have shape `[3, H, W]`, otherwise
|
||||||
|
`[1, H, W]` for mode `L` (grayscale) or `[4, H, W]` for mode `RGBA`.
|
||||||
|
|
||||||
|
Values are clamped to the range `[0, 1]`.
|
||||||
|
"""
|
||||||
|
image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
match image.mode:
|
||||||
|
case "L":
|
||||||
|
image_tensor = image_tensor.unsqueeze(0)
|
||||||
|
case "RGBA" | "RGB":
|
||||||
|
image_tensor = image_tensor.permute(2, 0, 1)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unsupported image mode: {image.mode}")
|
||||||
|
|
||||||
|
return image_tensor.unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
def tensor_to_image(tensor: Tensor) -> Image.Image:
|
def tensor_to_image(tensor: Tensor) -> Image.Image:
|
||||||
return Image.fromarray((tensor.clamp(0, 1).squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype("uint8")) # type: ignore
|
"""
|
||||||
|
Convert a Tensor to a PIL Image.
|
||||||
|
|
||||||
|
The tensor must have shape `[1, channels, height, width]` where the number of
|
||||||
|
channels is either 1 (grayscale) or 3 (RGB) or 4 (RGBA).
|
||||||
|
|
||||||
|
Expected values are in the range `[0, 1]` and are clamped to this range.
|
||||||
|
"""
|
||||||
|
assert tensor.ndim == 4 and tensor.shape[0] == 1, f"Unsupported tensor shape: {tensor.shape}"
|
||||||
|
num_channels = tensor.shape[1]
|
||||||
|
tensor = tensor.clamp(0, 1).squeeze(0)
|
||||||
|
|
||||||
|
match num_channels:
|
||||||
|
case 1:
|
||||||
|
tensor = tensor.squeeze(0)
|
||||||
|
case 3 | 4:
|
||||||
|
tensor = tensor.permute(1, 2, 0)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unsupported number of channels: {num_channels}")
|
||||||
|
|
||||||
|
return Image.fromarray((tensor.cpu().numpy() * 255).astype("uint8")) # type: ignore[reportUnknownType]
|
||||||
|
|
||||||
|
|
||||||
def safe_open(
|
def safe_open(
|
||||||
|
|
|
@ -3,10 +3,11 @@ from warnings import warn
|
||||||
|
|
||||||
from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore
|
from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore
|
||||||
from torch import device as Device, dtype as DType
|
from torch import device as Device, dtype as DType
|
||||||
|
from PIL import Image
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from refiners.fluxion.utils import gaussian_blur, manual_seed
|
from refiners.fluxion.utils import gaussian_blur, image_to_tensor, manual_seed, tensor_to_image
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -47,3 +48,17 @@ def test_gaussian_blur(test_device: Device, blur_input: BlurInput) -> None:
|
||||||
our_blur = gaussian_blur(tensor, blur_input.kernel_size, blur_input.sigma)
|
our_blur = gaussian_blur(tensor, blur_input.kernel_size, blur_input.sigma)
|
||||||
|
|
||||||
assert torch.equal(our_blur, ref_blur)
|
assert torch.equal(our_blur, ref_blur)
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_to_tensor() -> None:
|
||||||
|
image = Image.new("RGB", (512, 512))
|
||||||
|
|
||||||
|
assert image_to_tensor(image).shape == (1, 3, 512, 512)
|
||||||
|
assert image_to_tensor(image.convert("L")).shape == (1, 1, 512, 512)
|
||||||
|
assert image_to_tensor(image.convert("RGBA")).shape == (1, 4, 512, 512)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_to_image() -> None:
|
||||||
|
assert tensor_to_image(torch.zeros(1, 3, 512, 512)).mode == "RGB"
|
||||||
|
assert tensor_to_image(torch.zeros(1, 1, 512, 512)).mode == "L"
|
||||||
|
assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA"
|
||||||
|
|
Loading…
Reference in a new issue