diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index deb0d46..47789a2 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -187,20 +187,26 @@ def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: def summarize_tensor(tensor: torch.Tensor, /) -> str: - return ( - "Tensor(" - + ", ".join( + info_list = [ + f"shape=({', '.join(map(str, tensor.shape))})", + f"dtype={str(object=tensor.dtype).removeprefix('torch.')}", + f"device={tensor.device}", + ] + if not tensor.is_complex(): + info_list.extend( [ - f"shape=({', '.join(map(str, tensor.shape))})", - f"dtype={str(object=tensor.dtype).removeprefix('torch.')}", - f"device={tensor.device}", f"min={tensor.min():.2f}", # type: ignore f"max={tensor.max():.2f}", # type: ignore - f"mean={tensor.mean():.2f}", - f"std={tensor.std():.2f}", - f"norm={norm(x=tensor):.2f}", - f"grad={tensor.requires_grad}", ] ) - + ")" + + info_list.extend( + [ + f"mean={tensor.float().mean():.2f}", + f"std={tensor.float().std():.2f}", + f"norm={norm(x=tensor.float()):.2f}", + f"grad={tensor.requires_grad}", + ] ) + + return "Tensor(" + ", ".join(info_list) + ")" diff --git a/tests/fluxion/test_utils.py b/tests/fluxion/test_utils.py index 8837550..1edb04d 100644 --- a/tests/fluxion/test_utils.py +++ b/tests/fluxion/test_utils.py @@ -7,7 +7,14 @@ from PIL import Image from torch import device as Device, dtype as DType from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore -from refiners.fluxion.utils import gaussian_blur, image_to_tensor, manual_seed, no_grad, tensor_to_image +from refiners.fluxion.utils import ( + gaussian_blur, + image_to_tensor, + manual_seed, + no_grad, + summarize_tensor, + tensor_to_image, +) @dataclass @@ -64,6 +71,15 @@ def test_tensor_to_image() -> None: assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA" +def test_summarize_tensor() -> None: + assert summarize_tensor(torch.zeros(1, 3, 512, 512).int()) + assert summarize_tensor(torch.zeros(1, 3, 512, 512).float()) + assert summarize_tensor(torch.zeros(1, 3, 512, 512).double()) + assert summarize_tensor(torch.complex(torch.zeros(1, 3, 512, 512), torch.zeros(1, 3, 512, 512))) + assert summarize_tensor(torch.zeros(1, 3, 512, 512).bfloat16()) + assert summarize_tensor(torch.zeros(1, 3, 512, 512).bool()) + + def test_no_grad() -> None: x = torch.randn(1, 1, requires_grad=True)