mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
fix: summarize_tensor(tensor) when tensor.numel() == 0
This commit is contained in:
parent
2b4bc77534
commit
91aea9b7ff
|
@ -196,12 +196,13 @@ def summarize_tensor(tensor: torch.Tensor, /) -> str:
|
||||||
if tensor.is_complex():
|
if tensor.is_complex():
|
||||||
tensor_f = tensor.real.float()
|
tensor_f = tensor.real.float()
|
||||||
else:
|
else:
|
||||||
info_list.extend(
|
if tensor.numel() > 0:
|
||||||
[
|
info_list.extend(
|
||||||
f"min={tensor.min():.2f}", # type: ignore
|
[
|
||||||
f"max={tensor.max():.2f}", # type: ignore
|
f"min={tensor.min():.2f}", # type: ignore
|
||||||
]
|
f"max={tensor.max():.2f}", # type: ignore
|
||||||
)
|
]
|
||||||
|
)
|
||||||
tensor_f = tensor.float()
|
tensor_f = tensor.float()
|
||||||
|
|
||||||
info_list.extend(
|
info_list.extend(
|
||||||
|
|
|
@ -79,6 +79,7 @@ def test_summarize_tensor() -> None:
|
||||||
assert summarize_tensor(torch.complex(torch.zeros(1, 3, 512, 512), torch.zeros(1, 3, 512, 512)))
|
assert summarize_tensor(torch.complex(torch.zeros(1, 3, 512, 512), torch.zeros(1, 3, 512, 512)))
|
||||||
assert summarize_tensor(torch.zeros(1, 3, 512, 512).bfloat16())
|
assert summarize_tensor(torch.zeros(1, 3, 512, 512).bfloat16())
|
||||||
assert summarize_tensor(torch.zeros(1, 3, 512, 512).bool())
|
assert summarize_tensor(torch.zeros(1, 3, 512, 512).bool())
|
||||||
|
assert summarize_tensor(torch.zeros(1, 0, 512, 512).int())
|
||||||
|
|
||||||
|
|
||||||
def test_no_grad() -> None:
|
def test_no_grad() -> None:
|
||||||
|
|
Loading…
Reference in a new issue