summarize_tensor: fix minor warning

Calling `tensor.float()` on a complex tensor raises a warning:

    UserWarning: Casting complex values to real discards the imaginary
    part (Triggered internally at ../aten/src/ATen/native/Copy.cpp:299.)

Follow up of #171
This commit is contained in:
Cédric Deltheil 2024-01-19 11:22:08 +01:00 committed by Cédric Deltheil
parent aa9f572611
commit fde61757fb

View file

@ -193,19 +193,22 @@ def summarize_tensor(tensor: torch.Tensor, /) -> str:
f"dtype={str(object=tensor.dtype).removeprefix('torch.')}", f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",
f"device={tensor.device}", f"device={tensor.device}",
] ]
if not tensor.is_complex(): if tensor.is_complex():
tensor_f = tensor.real.float()
else:
info_list.extend( info_list.extend(
[ [
f"min={tensor.min():.2f}", # type: ignore f"min={tensor.min():.2f}", # type: ignore
f"max={tensor.max():.2f}", # type: ignore f"max={tensor.max():.2f}", # type: ignore
] ]
) )
tensor_f = tensor.float()
info_list.extend( info_list.extend(
[ [
f"mean={tensor.float().mean():.2f}", f"mean={tensor_f.mean():.2f}",
f"std={tensor.float().std():.2f}", f"std={tensor_f.std():.2f}",
f"norm={norm(x=tensor.float()):.2f}", f"norm={norm(x=tensor_f):.2f}",
f"grad={tensor.requires_grad}", f"grad={tensor.requires_grad}",
] ]
) )