diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index 09d9c06..6052ebb 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -193,19 +193,22 @@ def summarize_tensor(tensor: torch.Tensor, /) -> str: f"dtype={str(object=tensor.dtype).removeprefix('torch.')}", f"device={tensor.device}", ] - if not tensor.is_complex(): + if tensor.is_complex(): + tensor_f = tensor.real.float() + else: info_list.extend( [ f"min={tensor.min():.2f}", # type: ignore f"max={tensor.max():.2f}", # type: ignore ] ) + tensor_f = tensor.float() info_list.extend( [ - f"mean={tensor.float().mean():.2f}", - f"std={tensor.float().std():.2f}", - f"norm={norm(x=tensor.float()):.2f}", + f"mean={tensor_f.mean():.2f}", + f"std={tensor_f.std():.2f}", + f"norm={norm(x=tensor_f):.2f}", f"grad={tensor.requires_grad}", ] )