Skip to content

Utils

image_to_tensor

image_to_tensor(
    image: Image,
    device: device | str | None = None,
    dtype: dtype | None = None,
) -> Tensor

Convert a PIL Image to a Tensor.

Parameters:

Name Type Description Default
image Image

The image to convert.

required
device device | str | None

The device to use for the tensor.

None
dtype dtype | None

The dtype to use for the tensor.

None

Returns:

Type Description
Tensor

The converted tensor.

Note

The returned tensor has a leading batch dimension: its shape is [1, 3, H, W] for mode RGB, [1, 1, H, W] for mode L (grayscale) and [1, 4, H, W] for mode RGBA. Any other mode raises a ValueError.

Values are normalized to the range [0, 1].

Source code in src/refiners/fluxion/utils.py
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
    """Convert a PIL Image to a Tensor.

    Args:
        image: The image to convert.
        device: The device to use for the tensor.
        dtype: The dtype to use for the tensor.

    Returns:
        The converted tensor, with a leading batch dimension of size 1.

    Raises:
        ValueError: If the image mode is not one of `L`, `RGB` or `RGBA`.

    Note:
        The returned tensor has shape `[1, C, H, W]`, where `C` is 1 for mode `L` (grayscale),
        3 for mode `RGB` and 4 for mode `RGBA`. This is the same layout `tensor_to_image` accepts.

        Values are normalized to the range `[0, 1]`.
    """
    assert isinstance(image.mode, str)  # type: ignore
    # Reject unsupported modes up front, before paying for the full-resolution
    # array conversion below.
    if image.mode not in ("L", "RGB", "RGBA"):
        raise ValueError(f"Unsupported image mode: {image.mode}")

    # The array is (H, W) for mode L and (H, W, C) for RGB/RGBA.
    pixels = array(image).astype(float32) / 255.0
    image_tensor = torch.tensor(pixels, device=device, dtype=dtype)

    if image.mode == "L":
        image_tensor = image_tensor.unsqueeze(0)  # (H, W) -> (1, H, W)
    else:
        image_tensor = image_tensor.permute(2, 0, 1)  # (H, W, C) -> (C, H, W)

    return image_tensor.unsqueeze(0)  # add the batch dimension -> (1, C, H, W)

load_from_safetensors

load_from_safetensors(
    path: Path | str, device: device | str = "cpu"
) -> dict[str, Tensor]

Load tensors from a SafeTensor file from disk.

Parameters:

Name Type Description Default
path Path | str

The path to the file.

required
device device | str

The device to use for the tensors.

'cpu'

Returns:

Type Description
dict[str, Tensor]

The loaded tensors.

Source code in src/refiners/fluxion/utils.py
def load_from_safetensors(path: Path | str, device: Device | str = "cpu") -> dict[str, Tensor]:
    """Read every tensor stored in a SafeTensors file on disk.

    Args:
        path: Location of the file.
        device: Device to use for the loaded tensors, given as a `torch.device` or its string name.

    Returns:
        A mapping from tensor names to the loaded tensors.
    """
    # The loader is handed the device in string form, so normalize `torch.device` values.
    device_name = str(device)
    loaded = _load_file(path, device_name)
    return loaded

load_tensors

load_tensors(
    path: Path | str, /, device: device | str = "cpu"
) -> dict[str, Tensor]

Load tensors from a file saved with torch.save from disk.

Note

This function uses the weights_only mode of torch.load for additional safety.

Warning

Still, only load data you trust and favor using load_from_safetensors instead.

Source code in src/refiners/fluxion/utils.py
def load_tensors(path: Path | str, /, device: Device | str = "cpu") -> dict[str, Tensor]:
    """Load tensors from a file saved with `torch.save` from disk.

    Note:
        This function uses the `weights_only` mode of `torch.load` for additional safety.

    Warning:
        Still, **only load data you trust** and favor using
        [`load_from_safetensors`][refiners.fluxion.utils.load_from_safetensors] instead.
    """
    # see https://github.com/pytorch/pytorch/issues/97207#issuecomment-1494781560
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
        tensors = torch.load(path, map_location=device, weights_only=True)  # type: ignore

    assert isinstance(tensors, dict) and all(
        isinstance(key, str) and isinstance(value, Tensor)
        for key, value in tensors.items()  # type: ignore
    ), "Invalid tensor file, expected a dict[str, Tensor]"

    return cast(dict[str, Tensor], tensors)

save_to_safetensors

save_to_safetensors(
    path: Path | str,
    tensors: dict[str, Tensor],
    metadata: dict[str, str] | None = None,
) -> None

Save tensors to a SafeTensor file on disk.

Parameters:

Name Type Description Default
path Path | str

The path to the file.

required
tensors dict[str, Tensor]

The tensors to save.

required
metadata dict[str, str] | None

The metadata to save.

None
Source code in src/refiners/fluxion/utils.py
def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None:
    """Write a dictionary of named tensors to disk in the SafeTensors format.

    Args:
        path: Destination file path.
        tensors: The named tensors to store.
        metadata: Optional string-to-string metadata, forwarded to the writer as-is.
    """
    # The underlying writer takes the tensors first and the destination second,
    # the reverse of this function's own parameter order.
    destination = path
    _save_file(tensors, destination, metadata)  # type: ignore

str_to_dtype

str_to_dtype(dtype: str) -> dtype

Converts a string dtype to a torch.dtype.

See also https://pytorch.org/docs/stable/tensor_attributes.html#torch-dtype

Source code in src/refiners/fluxion/utils.py
def str_to_dtype(dtype: str) -> torch.dtype:
    """Map a dtype name (case-insensitive) to the corresponding `torch.dtype`.

    Canonical names (e.g. `"float32"`) and their short aliases (e.g. `"float"`,
    `"half"`, `"long"`) are both accepted.

    See also https://pytorch.org/docs/stable/tensor_attributes.html#torch-dtype

    Args:
        dtype: The name of the dtype.

    Returns:
        The matching `torch.dtype`.

    Raises:
        ValueError: If the name is not recognized.
    """
    aliases: dict[str, torch.dtype] = {
        "float32": torch.float32,
        "float": torch.float32,
        "float64": torch.float64,
        "double": torch.float64,
        "complex64": torch.complex64,
        "cfloat": torch.complex64,
        "complex128": torch.complex128,
        "cdouble": torch.complex128,
        "float16": torch.float16,
        "half": torch.float16,
        "bfloat16": torch.bfloat16,
        "uint8": torch.uint8,
        "int8": torch.int8,
        "int16": torch.int16,
        "short": torch.int16,
        "int32": torch.int32,
        "int": torch.int32,
        "int64": torch.int64,
        "long": torch.int64,
        "bool": torch.bool,
    }
    resolved = aliases.get(dtype.lower())
    if resolved is None:
        raise ValueError(f"Unknown dtype: {dtype}")
    return resolved

summarize_tensor

summarize_tensor(tensor: Tensor) -> str

Summarize a tensor.

This helper function returns a one-line string describing the shape, dtype, device, min, max, mean, std, norm and grad flag of a tensor.

Parameters:

Name Type Description Default
tensor Tensor

The tensor to summarize.

required

Returns:

Type Description
str

The summary string.

Source code in src/refiners/fluxion/utils.py
def summarize_tensor(tensor: torch.Tensor, /) -> str:
    """Summarize a tensor.

    This helper function builds a one-line string describing the shape, dtype, device,
    min, max, mean, std, norm and grad flag of a tensor.

    Args:
        tensor: The tensor to summarize.

    Returns:
        The summary string, formatted as `Tensor(shape=(...), dtype=..., device=..., ...)`.

    Note:
        `min` and `max` are omitted for empty tensors and for complex tensors.
        For complex tensors, `mean`, `std` and `norm` are computed on the real part only.
    """
    info_list = [
        f"shape=({', '.join(map(str, tensor.shape))})",
        f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",
        f"device={tensor.device}",
    ]
    # The statistics below are for display only, so skip autograd bookkeeping.
    # This avoids building a throwaway graph when `tensor` requires grad.
    with torch.no_grad():
        if tensor.is_complex():
            # min/max are not reported for complex values; only the real part is summarized.
            tensor_f = tensor.real.float()
        else:
            # min/max cannot be reduced over an empty tensor, so they are skipped then.
            if tensor.numel() > 0:
                info_list.extend(
                    [
                        f"min={tensor.min():.2f}",  # type: ignore
                        f"max={tensor.max():.2f}",  # type: ignore
                    ]
                )
            tensor_f = tensor.float()

        info_list.extend(
            [
                f"mean={tensor_f.mean():.2f}",
                f"std={tensor_f.std():.2f}",
                f"norm={norm(x=tensor_f):.2f}",
                f"grad={tensor.requires_grad}",
            ]
        )

    return "Tensor(" + ", ".join(info_list) + ")"

tensor_to_image

tensor_to_image(tensor: Tensor) -> Image

Convert a Tensor to a PIL Image.

Parameters:

Name Type Description Default
tensor Tensor

The tensor to convert.

required

Returns:

Type Description
Image

The converted image.

Note

The tensor must have shape [1, channels, height, width] where the number of channels is either 1 (grayscale) or 3 (RGB) or 4 (RGBA).

Expected values are in the range [0, 1] and are clamped to this range.

Source code in src/refiners/fluxion/utils.py
def tensor_to_image(tensor: Tensor) -> Image.Image:
    """Convert a Tensor to a PIL Image.

    Args:
        tensor: The tensor to convert.

    Returns:
        The converted image.

    Note:
        The tensor must have shape `[1, channels, height, width]` where the number of
        channels is either 1 (grayscale) or 3 (RGB) or 4 (RGBA).

        Expected values are in the range `[0, 1]` and are clamped to this range.
    """
    assert tensor.ndim == 4 and tensor.shape[0] == 1, f"Unsupported tensor shape: {tensor.shape}"
    num_channels = tensor.shape[1]
    tensor = tensor.clamp(0, 1).squeeze(0)
    tensor = tensor.to(torch.float32)  # to avoid numpy error with bfloat16

    match num_channels:
        case 1:
            tensor = tensor.squeeze(0)
        case 3 | 4:
            tensor = tensor.permute(1, 2, 0)
        case _:
            raise ValueError(f"Unsupported number of channels: {num_channels}")

    return Image.fromarray((tensor.cpu().numpy() * 255).astype("uint8"))  # type: ignore[reportUnknownType]