mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
fix: summarize_tensor(tensor) when tensor.numel() == 0
This commit is contained in:
parent
2b4bc77534
commit
91aea9b7ff
|
@ -196,12 +196,13 @@ def summarize_tensor(tensor: torch.Tensor, /) -> str:
|
|||
if tensor.is_complex():
|
||||
tensor_f = tensor.real.float()
|
||||
else:
|
||||
info_list.extend(
|
||||
[
|
||||
f"min={tensor.min():.2f}", # type: ignore
|
||||
f"max={tensor.max():.2f}", # type: ignore
|
||||
]
|
||||
)
|
||||
if tensor.numel() > 0:
|
||||
info_list.extend(
|
||||
[
|
||||
f"min={tensor.min():.2f}", # type: ignore
|
||||
f"max={tensor.max():.2f}", # type: ignore
|
||||
]
|
||||
)
|
||||
tensor_f = tensor.float()
|
||||
|
||||
info_list.extend(
|
||||
|
|
|
@ -79,6 +79,7 @@ def test_summarize_tensor() -> None:
|
|||
assert summarize_tensor(torch.complex(torch.zeros(1, 3, 512, 512), torch.zeros(1, 3, 512, 512)))
|
||||
assert summarize_tensor(torch.zeros(1, 3, 512, 512).bfloat16())
|
||||
assert summarize_tensor(torch.zeros(1, 3, 512, 512).bool())
|
||||
assert summarize_tensor(torch.zeros(1, 0, 512, 512).int())
|
||||
|
||||
|
||||
def test_no_grad() -> None:
|
||||
|
|
Loading…
Reference in a new issue