From fde61757fb5432b19c66d5206995e1a330635bad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Deltheil?= Date: Fri, 19 Jan 2024 11:22:08 +0100 Subject: [PATCH] summarize_tensor: fix minor warning Calling `tensor.float()` on a complex tensor raises a warning: UserWarning: Casting complex values to real discards the imaginary part (Triggered internally at ../aten/src/ATen/native/Copy.cpp:299.) Follow up of #171 --- src/refiners/fluxion/utils.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index 09d9c06..6052ebb 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -193,19 +193,22 @@ def summarize_tensor(tensor: torch.Tensor, /) -> str: f"dtype={str(object=tensor.dtype).removeprefix('torch.')}", f"device={tensor.device}", ] - if not tensor.is_complex(): + if tensor.is_complex(): + tensor_f = tensor.real.float() + else: info_list.extend( [ f"min={tensor.min():.2f}", # type: ignore f"max={tensor.max():.2f}", # type: ignore ] ) + tensor_f = tensor.float() info_list.extend( [ - f"mean={tensor.float().mean():.2f}", - f"std={tensor.float().std():.2f}", - f"norm={norm(x=tensor.float()):.2f}", + f"mean={tensor_f.mean():.2f}", + f"std={tensor_f.std():.2f}", + f"norm={norm(x=tensor_f):.2f}", f"grad={tensor.requires_grad}", ] )