diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index c55b0a6..3f5e710 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -102,13 +102,49 @@ def gaussian_blur( def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor: - return torch.tensor(array(image).astype(float32).transpose(2, 0, 1) / 255.0, device=device, dtype=dtype).unsqueeze( - 0 - ) + """ + Convert a PIL Image to a Tensor. + + If the image is in mode `RGB` the tensor will have shape `[3, H, W]`, otherwise + `[1, H, W]` for mode `L` (grayscale) or `[4, H, W]` for mode `RGBA`. + + Values are clamped to the range `[0, 1]`. + """ + image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype) + + match image.mode: + case "L": + image_tensor = image_tensor.unsqueeze(0) + case "RGBA" | "RGB": + image_tensor = image_tensor.permute(2, 0, 1) + case _: + raise ValueError(f"Unsupported image mode: {image.mode}") + + return image_tensor.unsqueeze(0) def tensor_to_image(tensor: Tensor) -> Image.Image: - return Image.fromarray((tensor.clamp(0, 1).squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype("uint8")) # type: ignore + """ + Convert a Tensor to a PIL Image. + + The tensor must have shape `[1, channels, height, width]` where the number of + channels is either 1 (grayscale) or 3 (RGB) or 4 (RGBA). + + Expected values are in the range `[0, 1]` and are clamped to this range. + """ + assert tensor.ndim == 4 and tensor.shape[0] == 1, f"Unsupported tensor shape: {tensor.shape}" + num_channels = tensor.shape[1] + tensor = tensor.clamp(0, 1).squeeze(0) + + match num_channels: + case 1: + tensor = tensor.squeeze(0) + case 3 | 4: + tensor = tensor.permute(1, 2, 0) + case _: + raise ValueError(f"Unsupported number of channels: {num_channels}") + + return Image.fromarray((tensor.cpu().numpy() * 255).astype("uint8")) # type: ignore[reportUnknownType] def safe_open( diff --git a/tests/fluxion/test_utils.py b/tests/fluxion/test_utils.py index c34c169..e83b789 100644 --- a/tests/fluxion/test_utils.py +++ b/tests/fluxion/test_utils.py @@ -3,10 +3,11 @@ from warnings import warn from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore from torch import device as Device, dtype as DType +from PIL import Image import pytest import torch -from refiners.fluxion.utils import gaussian_blur, manual_seed +from refiners.fluxion.utils import gaussian_blur, image_to_tensor, manual_seed, tensor_to_image @dataclass @@ -47,3 +48,17 @@ def test_gaussian_blur(test_device: Device, blur_input: BlurInput) -> None: our_blur = gaussian_blur(tensor, blur_input.kernel_size, blur_input.sigma) assert torch.equal(our_blur, ref_blur) + + +def test_image_to_tensor() -> None: + image = Image.new("RGB", (512, 512)) + + assert image_to_tensor(image).shape == (1, 3, 512, 512) + assert image_to_tensor(image.convert("L")).shape == (1, 1, 512, 512) + assert image_to_tensor(image.convert("RGBA")).shape == (1, 4, 512, 512) + + +def test_tensor_to_image() -> None: + assert tensor_to_image(torch.zeros(1, 3, 512, 512)).mode == "RGB" + assert tensor_to_image(torch.zeros(1, 1, 512, 512)).mode == "L" + assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA"