cast to float32 before converting to image in tensor_to_image to fix bfloat16 conversion

This commit is contained in:
limiteinductive 2024-01-15 15:50:40 +01:00 committed by Benjamin Trom
parent 7f722029be
commit d9ae7ca6a5
2 changed files with 2 additions and 0 deletions

View file

@ -146,6 +146,7 @@ def tensor_to_image(tensor: Tensor) -> Image.Image:
assert tensor.ndim == 4 and tensor.shape[0] == 1, f"Unsupported tensor shape: {tensor.shape}" assert tensor.ndim == 4 and tensor.shape[0] == 1, f"Unsupported tensor shape: {tensor.shape}"
num_channels = tensor.shape[1] num_channels = tensor.shape[1]
tensor = tensor.clamp(0, 1).squeeze(0) tensor = tensor.clamp(0, 1).squeeze(0)
tensor = tensor.to(torch.float32) # to avoid numpy error with bfloat16
match num_channels: match num_channels:
case 1: case 1:

View file

@ -69,6 +69,7 @@ def test_tensor_to_image() -> None:
assert tensor_to_image(torch.zeros(1, 3, 512, 512)).mode == "RGB" assert tensor_to_image(torch.zeros(1, 3, 512, 512)).mode == "RGB"
assert tensor_to_image(torch.zeros(1, 1, 512, 512)).mode == "L" assert tensor_to_image(torch.zeros(1, 1, 512, 512)).mode == "L"
assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA" assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA"
assert tensor_to_image(torch.zeros(1, 3, 512, 512, dtype=torch.bfloat16)).mode == "RGB"
def test_summarize_tensor() -> None: def test_summarize_tensor() -> None: