diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index 47789a2..09d9c06 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -146,6 +146,7 @@ def tensor_to_image(tensor: Tensor) -> Image.Image: assert tensor.ndim == 4 and tensor.shape[0] == 1, f"Unsupported tensor shape: {tensor.shape}" num_channels = tensor.shape[1] tensor = tensor.clamp(0, 1).squeeze(0) + tensor = tensor.to(torch.float32) # to avoid numpy error with bfloat16 match num_channels: case 1: diff --git a/tests/fluxion/test_utils.py b/tests/fluxion/test_utils.py index 1edb04d..a86d0cd 100644 --- a/tests/fluxion/test_utils.py +++ b/tests/fluxion/test_utils.py @@ -69,6 +69,7 @@ def test_tensor_to_image() -> None: assert tensor_to_image(torch.zeros(1, 3, 512, 512)).mode == "RGB" assert tensor_to_image(torch.zeros(1, 1, 512, 512)).mode == "L" assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA" + assert tensor_to_image(torch.zeros(1, 3, 512, 512, dtype=torch.bfloat16)).mode == "RGB" def test_summarize_tensor() -> None: