mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
summarize_tensor: fix minor warning
Calling `tensor.float()` on a complex tensor raises a warning: UserWarning: Casting complex values to real discards the imaginary part (Triggered internally at ../aten/src/ATen/native/Copy.cpp:299.) Follow up of #171
This commit is contained in:
parent
aa9f572611
commit
fde61757fb
|
@ -193,19 +193,22 @@ def summarize_tensor(tensor: torch.Tensor, /) -> str:
|
||||||
f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",
|
f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",
|
||||||
f"device={tensor.device}",
|
f"device={tensor.device}",
|
||||||
]
|
]
|
||||||
if not tensor.is_complex():
|
if tensor.is_complex():
|
||||||
|
tensor_f = tensor.real.float()
|
||||||
|
else:
|
||||||
info_list.extend(
|
info_list.extend(
|
||||||
[
|
[
|
||||||
f"min={tensor.min():.2f}", # type: ignore
|
f"min={tensor.min():.2f}", # type: ignore
|
||||||
f"max={tensor.max():.2f}", # type: ignore
|
f"max={tensor.max():.2f}", # type: ignore
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
tensor_f = tensor.float()
|
||||||
|
|
||||||
info_list.extend(
|
info_list.extend(
|
||||||
[
|
[
|
||||||
f"mean={tensor.float().mean():.2f}",
|
f"mean={tensor_f.mean():.2f}",
|
||||||
f"std={tensor.float().std():.2f}",
|
f"std={tensor_f.std():.2f}",
|
||||||
f"norm={norm(x=tensor.float()):.2f}",
|
f"norm={norm(x=tensor_f):.2f}",
|
||||||
f"grad={tensor.requires_grad}",
|
f"grad={tensor.requires_grad}",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue