fix: summarize_tensor(tensor) when tensor.numel() == 0

This commit is contained in:
Pierre Colle 2024-01-19 15:36:19 +01:00 committed by Cédric Deltheil
parent 2b4bc77534
commit 91aea9b7ff
2 changed files with 8 additions and 6 deletions

View file

@ -196,12 +196,13 @@ def summarize_tensor(tensor: torch.Tensor, /) -> str:
if tensor.is_complex(): if tensor.is_complex():
tensor_f = tensor.real.float() tensor_f = tensor.real.float()
else: else:
info_list.extend( if tensor.numel() > 0:
[ info_list.extend(
f"min={tensor.min():.2f}", # type: ignore [
f"max={tensor.max():.2f}", # type: ignore f"min={tensor.min():.2f}", # type: ignore
] f"max={tensor.max():.2f}", # type: ignore
) ]
)
tensor_f = tensor.float() tensor_f = tensor.float()
info_list.extend( info_list.extend(

View file

@ -79,6 +79,7 @@ def test_summarize_tensor() -> None:
assert summarize_tensor(torch.complex(torch.zeros(1, 3, 512, 512), torch.zeros(1, 3, 512, 512))) assert summarize_tensor(torch.complex(torch.zeros(1, 3, 512, 512), torch.zeros(1, 3, 512, 512)))
assert summarize_tensor(torch.zeros(1, 3, 512, 512).bfloat16()) assert summarize_tensor(torch.zeros(1, 3, 512, 512).bfloat16())
assert summarize_tensor(torch.zeros(1, 3, 512, 512).bool()) assert summarize_tensor(torch.zeros(1, 3, 512, 512).bool())
assert summarize_tensor(torch.zeros(1, 0, 512, 512).int())
def test_no_grad() -> None: def test_no_grad() -> None: