diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index 68b97c2..38a3e7d 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -121,13 +121,21 @@ def images_to_tensor( def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor: - """ - Convert a PIL Image to a Tensor. + """Convert a PIL Image to a Tensor. - If the image is in mode `RGB` the tensor will have shape `[3, H, W]`, otherwise - `[1, H, W]` for mode `L` (grayscale) or `[4, H, W]` for mode `RGBA`. + Args: + image: The image to convert. + device: The device to use for the tensor. + dtype: The dtype to use for the tensor. - Values are clamped to the range `[0, 1]`. + Returns: + The converted tensor. + + Note: + If the image is in mode `RGB` the tensor will have shape `[3, H, W]`, + otherwise `[1, H, W]` for mode `L` (grayscale) or `[4, H, W]` for mode `RGBA`. + + Values are clamped to the range `[0, 1]`. """ image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype) @@ -147,13 +155,19 @@ def tensor_to_images(tensor: Tensor) -> list[Image.Image]: def tensor_to_image(tensor: Tensor) -> Image.Image: - """ - Convert a Tensor to a PIL Image. + """Convert a Tensor to a PIL Image. - The tensor must have shape `[1, channels, height, width]` where the number of - channels is either 1 (grayscale) or 3 (RGB) or 4 (RGBA). + Args: + tensor: The tensor to convert. - Expected values are in the range `[0, 1]` and are clamped to this range. + Returns: + The converted image. + + Note: + The tensor must have shape `[1, channels, height, width]` where the number of + channels is either 1 (grayscale) or 3 (RGB) or 4 (RGBA). + + Expected values are in the range `[0, 1]` and are clamped to this range. 
""" assert tensor.ndim == 4 and tensor.shape[0] == 1, f"Unsupported tensor shape: {tensor.shape}" num_channels = tensor.shape[1] @@ -176,20 +190,38 @@ def safe_open( framework: Literal["pytorch", "tensorflow", "flax", "numpy"], device: Device | str = "cpu", ) -> dict[str, Tensor]: + """Open a SafeTensor file from disk. + + Args: + path: The path to the file. + framework: The framework used to save the file. + device: The device to use for the tensors. + + Returns: + The loaded tensors. + """ framework_mapping = { "pytorch": "pt", "tensorflow": "tf", "flax": "flax", "numpy": "numpy", } - return _safe_open(str(path), framework=framework_mapping[framework], device=str(device)) # type: ignore + return _safe_open( + str(path), + framework=framework_mapping[framework], + device=str(device), + ) # type: ignore def load_tensors(path: Path | str, /, device: Device | str = "cpu") -> dict[str, Tensor]: - """ - Load tensors from a file saved with `torch.save` from disk using the `weights_only` mode - for additional safety (see `torch.load` for more details). Still, *only load data you trust* and - favor using `load_from_safetensors`. + """Load tensors from a file saved with `torch.save` from disk. + + Note: + This function uses the `weights_only` mode of `torch.load` for additional safety. + + Warning: + Still, **only load data you trust** and favor using + [`load_from_safetensors`](refiners.fluxion.utils.load_from_safetensors) instead. """ # see https://github.com/pytorch/pytorch/issues/97207#issuecomment-1494781560 with warnings.catch_warnings(): @@ -205,15 +237,41 @@ def load_tensors(path: Path | str, /, device: Device | str = "cpu") -> dict[str, def load_from_safetensors(path: Path | str, device: Device | str = "cpu") -> dict[str, Tensor]: + """Load tensors from a SafeTensor file from disk. + + Args: + path: The path to the file. + device: The device to use for the tensors. + + Returns: + The loaded tensors. 
+ """ with safe_open(path=path, framework="pytorch", device=device) as tensors: # type: ignore return {key: tensors.get_tensor(key) for key in tensors.keys()} # type: ignore def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None: + """Save tensors to a SafeTensor file on disk. + + Args: + path: The path to the file. + tensors: The tensors to save. + metadata: The metadata to save. + """ _save_file(tensors, path, metadata) # type: ignore def summarize_tensor(tensor: torch.Tensor, /) -> str: + """Summarize a tensor. + + This helper function returns a string describing the shape, dtype, device, min, max, mean, std, norm and grad of a tensor. + + Args: + tensor: The tensor to summarize. + + Returns: + The summary string. + """ info_list = [ f"shape=({', '.join(map(str, tensor.shape))})", f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",