Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-22 06:08:46 +00:00)
(doc/fluxion/utils) add/convert docstrings to mkdocstrings format
parent 12a8dd6c85
commit e79c2bdde5
@@ -121,13 +121,21 @@ def images_to_tensor(
 
 
 def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
-    """
-    Convert a PIL Image to a Tensor.
-
-    If the image is in mode `RGB` the tensor will have shape `[3, H, W]`, otherwise
-    `[1, H, W]` for mode `L` (grayscale) or `[4, H, W]` for mode `RGBA`.
-
-    Values are clamped to the range `[0, 1]`.
+    """Convert a PIL Image to a Tensor.
+
+    Args:
+        image: The image to convert.
+        device: The device to use for the tensor.
+        dtype: The dtype to use for the tensor.
+
+    Returns:
+        The converted tensor.
+
+    Note:
+        If the image is in mode `RGB` the tensor will have shape `[3, H, W]`,
+        otherwise `[1, H, W]` for mode `L` (grayscale) or `[4, H, W]` for mode `RGBA`.
+
+        Values are clamped to the range `[0, 1]`.
     """
     image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype)
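For context only (not part of this commit): a minimal usage sketch of `image_to_tensor` as described by the new docstring. The blank test image, its size, and the variable names are invented for illustration.

```python
# Illustrative input: a blank 64x48 RGB image stands in for real data.
from PIL import Image

from refiners.fluxion.utils import image_to_tensor

image = Image.new("RGB", (64, 48), color="white")
tensor = image_to_tensor(image, device="cpu")

print(tensor.shape)  # an RGB input maps to a `[3, H, W]` tensor, per the docstring above
print(tensor.min().item(), tensor.max().item())  # values are clamped to the range [0, 1]
```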
@@ -147,13 +155,19 @@ def tensor_to_images(tensor: Tensor) -> list[Image.Image]:
 
 
 def tensor_to_image(tensor: Tensor) -> Image.Image:
-    """
-    Convert a Tensor to a PIL Image.
-
-    The tensor must have shape `[1, channels, height, width]` where the number of
-    channels is either 1 (grayscale) or 3 (RGB) or 4 (RGBA).
-
-    Expected values are in the range `[0, 1]` and are clamped to this range.
+    """Convert a Tensor to a PIL Image.
+
+    Args:
+        tensor: The tensor to convert.
+
+    Returns:
+        The converted image.
+
+    Note:
+        The tensor must have shape `[1, channels, height, width]` where the number of
+        channels is either 1 (grayscale) or 3 (RGB) or 4 (RGBA).
+
+        Expected values are in the range `[0, 1]` and are clamped to this range.
     """
     assert tensor.ndim == 4 and tensor.shape[0] == 1, f"Unsupported tensor shape: {tensor.shape}"
     num_channels = tensor.shape[1]
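Similarly, a short sketch (not from the commit) of `tensor_to_image`, following the `[1, channels, height, width]` constraint stated in the new docstring; the random tensor and output path are made up.

```python
# Illustrative input: random values in [0, 1] stand in for a real image tensor.
import torch

from refiners.fluxion.utils import tensor_to_image

tensor = torch.rand(1, 3, 48, 64)  # [1, channels, height, width], values in [0, 1]
image = tensor_to_image(tensor)    # returns a PIL.Image.Image
image.save("preview.png")          # hypothetical output path
```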
@@ -176,20 +190,38 @@ def safe_open(
     framework: Literal["pytorch", "tensorflow", "flax", "numpy"],
     device: Device | str = "cpu",
 ) -> dict[str, Tensor]:
+    """Open a SafeTensor file from disk.
+
+    Args:
+        path: The path to the file.
+        framework: The framework used to save the file.
+        device: The device to use for the tensors.
+
+    Returns:
+        The loaded tensors.
+    """
     framework_mapping = {
         "pytorch": "pt",
         "tensorflow": "tf",
         "flax": "flax",
         "numpy": "numpy",
     }
-    return _safe_open(str(path), framework=framework_mapping[framework], device=str(device))  # type: ignore
+    return _safe_open(
+        str(path),
+        framework=framework_mapping[framework],
+        device=str(device),
+    )  # type: ignore
 
 
 def load_tensors(path: Path | str, /, device: Device | str = "cpu") -> dict[str, Tensor]:
-    """
-    Load tensors from a file saved with `torch.save` from disk using the `weights_only` mode
-    for additional safety (see `torch.load` for more details). Still, *only load data you trust* and
-    favor using `load_from_safetensors`.
+    """Load tensors from a file saved with `torch.save` from disk.
+
+    Note:
+        This function uses the `weights_only` mode of `torch.load` for additional safety.
+
+    Warning:
+        Still, **only load data you trust** and favor using
+        [`load_from_safetensors`](refiners.fluxion.utils.load_from_safetensors) instead.
     """
     # see https://github.com/pytorch/pytorch/issues/97207#issuecomment-1494781560
     with warnings.catch_warnings():
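For the two loaders touched in this hunk, a hedged usage sketch (not part of the diff). The checkpoint file names are hypothetical; the `# type: ignore` on `safe_open` mirrors how `load_from_safetensors` calls it within this same module.

```python
# Hypothetical checkpoint files: load_tensors reads a file written with
# torch.save, safe_open reads a .safetensors file.
from refiners.fluxion.utils import load_tensors, safe_open

state_dict = load_tensors("checkpoint.pt", device="cpu")  # weights_only load, see docstring above

with safe_open("checkpoint.safetensors", framework="pytorch", device="cpu") as tensors:  # type: ignore
    names = list(tensors.keys())  # list the stored tensor names
```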
@@ -205,15 +237,41 @@ def load_tensors(path: Path | str, /, device: Device | str = "cpu") -> dict[str,
 
 
 def load_from_safetensors(path: Path | str, device: Device | str = "cpu") -> dict[str, Tensor]:
+    """Load tensors from a SafeTensor file from disk.
+
+    Args:
+        path: The path to the file.
+        device: The device to use for the tensors.
+
+    Returns:
+        The loaded tensors.
+    """
     with safe_open(path=path, framework="pytorch", device=device) as tensors:  # type: ignore
         return {key: tensors.get_tensor(key) for key in tensors.keys()}  # type: ignore
 
 
 def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None:
+    """Save tensors to a SafeTensor file on disk.
+
+    Args:
+        path: The path to the file.
+        tensors: The tensors to save.
+        metadata: The metadata to save.
+    """
     _save_file(tensors, path, metadata)  # type: ignore
 
 
 def summarize_tensor(tensor: torch.Tensor, /) -> str:
+    """Summarize a tensor.
+
+    This helper function prints the shape, dtype, device, min, max, mean, std, norm and grad of a tensor.
+
+    Args:
+        tensor: The tensor to summarize.
+
+    Returns:
+        The summary string.
+    """
     info_list = [
         f"shape=({', '.join(map(str, tensor.shape))})",
         f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",
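Finally, a small round-trip sketch (again, not part of the commit) tying together the three helpers documented in this hunk; the tensors, file name, and metadata are invented for illustration.

```python
# Illustrative round trip: save two tensors, reload them, summarize one.
import torch

from refiners.fluxion.utils import load_from_safetensors, save_to_safetensors, summarize_tensor

tensors = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
save_to_safetensors("toy.safetensors", tensors, metadata={"source": "example"})

loaded = load_from_safetensors("toy.safetensors", device="cpu")
print(summarize_tensor(loaded["weight"]))  # shape, dtype, device, min, max, mean, std, norm, grad
```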