simplify implementation of load_from_safetensors

This commit is contained in:
Pierre Chapuis 2024-09-12 14:15:06 +02:00
parent 2c0174f50e
commit 31b5f80496
3 changed files with 7 additions and 35 deletions

View file

@ -6,7 +6,7 @@ authors = [{ name = "The Finegrain Team", email = "bonjour@lagon.tech" }]
license = { text = "MIT License" }
dependencies = [
"torch>=2.1.1",
"safetensors>=0.4.0",
"safetensors>=0.4.5",
"pillow>=10.4.0",
"jaxtyping>=0.2.23",
"packaging>=23.2",

View file

@ -338,7 +338,7 @@ rpds-py==0.19.1
# via referencing
s3transfer==0.10.2
# via boto3
safetensors==0.4.3
safetensors==0.4.5
# via diffusers
# via refiners
# via timm
@ -347,6 +347,8 @@ segment-anything-hq==0.3
# via refiners
segment-anything-py==1.0.1
# via refiners
sentencepiece==0.2.0
# via refiners
sentry-sdk==2.12.0
# via wandb
setproctitle==1.3.3

View file

@ -1,13 +1,12 @@
import warnings
from pathlib import Path
from typing import Any, Iterable, Literal, TypeVar, cast
from typing import Any, Iterable, TypeVar, cast
import torch
from jaxtyping import Float
from numpy import array, float32
from PIL import Image
from safetensors import safe_open as _safe_open # type: ignore
from safetensors.torch import save_file as _save_file # type: ignore
from safetensors.torch import load_file as _load_file, save_file as _save_file # type: ignore
from torch import Tensor, device as Device, dtype as DType
from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad # type: ignore
@ -186,34 +185,6 @@ def tensor_to_image(tensor: Tensor) -> Image.Image:
return Image.fromarray((tensor.cpu().numpy() * 255).astype("uint8")) # type: ignore[reportUnknownType]
def safe_open(
path: Path | str,
framework: Literal["pytorch", "tensorflow", "flax", "numpy"],
device: Device | str = "cpu",
) -> dict[str, Tensor]:
"""Open a SafeTensor file from disk.
Args:
path: The path to the file.
framework: The framework used to save the file.
device: The device to use for the tensors.
Returns:
The loaded tensors.
"""
framework_mapping = {
"pytorch": "pt",
"tensorflow": "tf",
"flax": "flax",
"numpy": "numpy",
}
return _safe_open(
str(path),
framework=framework_mapping[framework],
device=str(device),
) # type: ignore
def load_tensors(path: Path | str, /, device: Device | str = "cpu") -> dict[str, Tensor]:
"""Load tensors from a file saved with `torch.save` from disk.
@ -247,8 +218,7 @@ def load_from_safetensors(path: Path | str, device: Device | str = "cpu") -> dic
Returns:
The loaded tensors.
"""
with safe_open(path=path, framework="pytorch", device=device) as tensors: # type: ignore
return {key: tensors.get_tensor(key) for key in tensors.keys()} # type: ignore
return _load_file(path, str(device))
def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None: