mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-12 16:18:22 +00:00
Make summarize_tensor robust to non-float dtypes (#171)
This commit is contained in:
parent
ce0f9887a3
commit
c141091afc
|
@ -187,20 +187,26 @@ def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata:
|
||||||
|
|
||||||
|
|
||||||
def summarize_tensor(tensor: torch.Tensor, /) -> str:
|
def summarize_tensor(tensor: torch.Tensor, /) -> str:
|
||||||
return (
|
info_list = [
|
||||||
"Tensor("
|
f"shape=({', '.join(map(str, tensor.shape))})",
|
||||||
+ ", ".join(
|
f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",
|
||||||
|
f"device={tensor.device}",
|
||||||
|
]
|
||||||
|
if not tensor.is_complex():
|
||||||
|
info_list.extend(
|
||||||
[
|
[
|
||||||
f"shape=({', '.join(map(str, tensor.shape))})",
|
|
||||||
f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",
|
|
||||||
f"device={tensor.device}",
|
|
||||||
f"min={tensor.min():.2f}", # type: ignore
|
f"min={tensor.min():.2f}", # type: ignore
|
||||||
f"max={tensor.max():.2f}", # type: ignore
|
f"max={tensor.max():.2f}", # type: ignore
|
||||||
f"mean={tensor.mean():.2f}",
|
|
||||||
f"std={tensor.std():.2f}",
|
|
||||||
f"norm={norm(x=tensor):.2f}",
|
|
||||||
f"grad={tensor.requires_grad}",
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
+ ")"
|
|
||||||
|
info_list.extend(
|
||||||
|
[
|
||||||
|
f"mean={tensor.float().mean():.2f}",
|
||||||
|
f"std={tensor.float().std():.2f}",
|
||||||
|
f"norm={norm(x=tensor.float()):.2f}",
|
||||||
|
f"grad={tensor.requires_grad}",
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return "Tensor(" + ", ".join(info_list) + ")"
|
||||||
|
|
|
@ -7,7 +7,14 @@ from PIL import Image
|
||||||
from torch import device as Device, dtype as DType
|
from torch import device as Device, dtype as DType
|
||||||
from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore
|
from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore
|
||||||
|
|
||||||
from refiners.fluxion.utils import gaussian_blur, image_to_tensor, manual_seed, no_grad, tensor_to_image
|
from refiners.fluxion.utils import (
|
||||||
|
gaussian_blur,
|
||||||
|
image_to_tensor,
|
||||||
|
manual_seed,
|
||||||
|
no_grad,
|
||||||
|
summarize_tensor,
|
||||||
|
tensor_to_image,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -64,6 +71,15 @@ def test_tensor_to_image() -> None:
|
||||||
assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA"
|
assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA"
|
||||||
|
|
||||||
|
|
||||||
|
def test_summarize_tensor() -> None:
|
||||||
|
assert summarize_tensor(torch.zeros(1, 3, 512, 512).int())
|
||||||
|
assert summarize_tensor(torch.zeros(1, 3, 512, 512).float())
|
||||||
|
assert summarize_tensor(torch.zeros(1, 3, 512, 512).double())
|
||||||
|
assert summarize_tensor(torch.complex(torch.zeros(1, 3, 512, 512), torch.zeros(1, 3, 512, 512)))
|
||||||
|
assert summarize_tensor(torch.zeros(1, 3, 512, 512).bfloat16())
|
||||||
|
assert summarize_tensor(torch.zeros(1, 3, 512, 512).bool())
|
||||||
|
|
||||||
|
|
||||||
def test_no_grad() -> None:
|
def test_no_grad() -> None:
|
||||||
x = torch.randn(1, 1, requires_grad=True)
|
x = torch.randn(1, 1, requires_grad=True)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue