simplify implementation of load_from_safetensors

This commit is contained in:
Pierre Chapuis 2024-09-12 14:15:06 +02:00
parent 2c0174f50e
commit 31b5f80496
3 changed files with 7 additions and 35 deletions

View file

@ -6,7 +6,7 @@ authors = [{ name = "The Finegrain Team", email = "bonjour@lagon.tech" }]
license = { text = "MIT License" } license = { text = "MIT License" }
dependencies = [ dependencies = [
"torch>=2.1.1", "torch>=2.1.1",
"safetensors>=0.4.0", "safetensors>=0.4.5",
"pillow>=10.4.0", "pillow>=10.4.0",
"jaxtyping>=0.2.23", "jaxtyping>=0.2.23",
"packaging>=23.2", "packaging>=23.2",

View file

@ -338,7 +338,7 @@ rpds-py==0.19.1
# via referencing # via referencing
s3transfer==0.10.2 s3transfer==0.10.2
# via boto3 # via boto3
safetensors==0.4.3 safetensors==0.4.5
# via diffusers # via diffusers
# via refiners # via refiners
# via timm # via timm
@ -347,6 +347,8 @@ segment-anything-hq==0.3
# via refiners # via refiners
segment-anything-py==1.0.1 segment-anything-py==1.0.1
# via refiners # via refiners
sentencepiece==0.2.0
# via refiners
sentry-sdk==2.12.0 sentry-sdk==2.12.0
# via wandb # via wandb
setproctitle==1.3.3 setproctitle==1.3.3

View file

@ -1,13 +1,12 @@
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import Any, Iterable, Literal, TypeVar, cast from typing import Any, Iterable, TypeVar, cast
import torch import torch
from jaxtyping import Float from jaxtyping import Float
from numpy import array, float32 from numpy import array, float32
from PIL import Image from PIL import Image
from safetensors import safe_open as _safe_open # type: ignore from safetensors.torch import load_file as _load_file, save_file as _save_file # type: ignore
from safetensors.torch import save_file as _save_file # type: ignore
from torch import Tensor, device as Device, dtype as DType from torch import Tensor, device as Device, dtype as DType
from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad # type: ignore from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad # type: ignore
@ -186,34 +185,6 @@ def tensor_to_image(tensor: Tensor) -> Image.Image:
return Image.fromarray((tensor.cpu().numpy() * 255).astype("uint8")) # type: ignore[reportUnknownType] return Image.fromarray((tensor.cpu().numpy() * 255).astype("uint8")) # type: ignore[reportUnknownType]
def safe_open(
path: Path | str,
framework: Literal["pytorch", "tensorflow", "flax", "numpy"],
device: Device | str = "cpu",
) -> dict[str, Tensor]:
"""Open a SafeTensor file from disk.
Args:
path: The path to the file.
framework: The framework used to save the file.
device: The device to use for the tensors.
Returns:
The loaded tensors.
"""
framework_mapping = {
"pytorch": "pt",
"tensorflow": "tf",
"flax": "flax",
"numpy": "numpy",
}
return _safe_open(
str(path),
framework=framework_mapping[framework],
device=str(device),
) # type: ignore
def load_tensors(path: Path | str, /, device: Device | str = "cpu") -> dict[str, Tensor]: def load_tensors(path: Path | str, /, device: Device | str = "cpu") -> dict[str, Tensor]:
"""Load tensors from a file saved with `torch.save` from disk. """Load tensors from a file saved with `torch.save` from disk.
@ -247,8 +218,7 @@ def load_from_safetensors(path: Path | str, device: Device | str = "cpu") -> dic
Returns: Returns:
The loaded tensors. The loaded tensors.
""" """
with safe_open(path=path, framework="pytorch", device=device) as tensors: # type: ignore return _load_file(path, str(device))
return {key: tensors.get_tensor(key) for key in tensors.keys()} # type: ignore
def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None: def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None: