fix: summarize_tensor(tensor) when tensor.numel() == 0

This commit is contained in:
Pierre Colle 2024-01-19 15:36:19 +01:00 committed by Cédric Deltheil
parent 2b4bc77534
commit 91aea9b7ff
2 changed files with 8 additions and 6 deletions

View file

@ -196,12 +196,13 @@ def summarize_tensor(tensor: torch.Tensor, /) -> str:
if tensor.is_complex():
tensor_f = tensor.real.float()
else:
info_list.extend(
[
f"min={tensor.min():.2f}", # type: ignore
f"max={tensor.max():.2f}", # type: ignore
]
)
if tensor.numel() > 0:
info_list.extend(
[
f"min={tensor.min():.2f}", # type: ignore
f"max={tensor.max():.2f}", # type: ignore
]
)
tensor_f = tensor.float()
info_list.extend(

View file

@ -79,6 +79,7 @@ def test_summarize_tensor() -> None:
assert summarize_tensor(torch.complex(torch.zeros(1, 3, 512, 512), torch.zeros(1, 3, 512, 512)))
assert summarize_tensor(torch.zeros(1, 3, 512, 512).bfloat16())
assert summarize_tensor(torch.zeros(1, 3, 512, 512).bool())
assert summarize_tensor(torch.zeros(1, 0, 512, 512).int())
def test_no_grad() -> None: