From 31b5f80496ade854d2aff2ad4934be5183322679 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Thu, 12 Sep 2024 14:15:06 +0200 Subject: [PATCH] simplify implementation of load_from_safetensors --- pyproject.toml | 2 +- requirements.lock | 4 +++- src/refiners/fluxion/utils.py | 36 +++-------------------------------- 3 files changed, 7 insertions(+), 35 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0818e80..e05a707 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ authors = [{ name = "The Finegrain Team", email = "bonjour@lagon.tech" }] license = { text = "MIT License" } dependencies = [ "torch>=2.1.1", - "safetensors>=0.4.0", + "safetensors>=0.4.5", "pillow>=10.4.0", "jaxtyping>=0.2.23", "packaging>=23.2", diff --git a/requirements.lock b/requirements.lock index e1c7f47..995769f 100644 --- a/requirements.lock +++ b/requirements.lock @@ -338,7 +338,7 @@ rpds-py==0.19.1 # via referencing s3transfer==0.10.2 # via boto3 -safetensors==0.4.3 +safetensors==0.4.5 # via diffusers # via refiners # via timm @@ -347,6 +347,8 @@ segment-anything-hq==0.3 # via refiners segment-anything-py==1.0.1 # via refiners +sentencepiece==0.2.0 + # via refiners sentry-sdk==2.12.0 # via wandb setproctitle==1.3.3 diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index d3dc0d4..f80589d 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -1,13 +1,12 @@ import warnings from pathlib import Path -from typing import Any, Iterable, Literal, TypeVar, cast +from typing import Any, Iterable, TypeVar, cast import torch from jaxtyping import Float from numpy import array, float32 from PIL import Image -from safetensors import safe_open as _safe_open # type: ignore -from safetensors.torch import save_file as _save_file # type: ignore +from safetensors.torch import load_file as _load_file, save_file as _save_file # type: ignore from torch import Tensor, device as Device, dtype as DType from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad # type: ignore @@ -186,34 +185,6 @@ def tensor_to_image(tensor: Tensor) -> Image.Image: return Image.fromarray((tensor.cpu().numpy() * 255).astype("uint8")) # type: ignore[reportUnknownType] -def safe_open( - path: Path | str, - framework: Literal["pytorch", "tensorflow", "flax", "numpy"], - device: Device | str = "cpu", -) -> dict[str, Tensor]: - """Open a SafeTensor file from disk. - - Args: - path: The path to the file. - framework: The framework used to save the file. - device: The device to use for the tensors. - - Returns: - The loaded tensors. - """ - framework_mapping = { - "pytorch": "pt", - "tensorflow": "tf", - "flax": "flax", - "numpy": "numpy", - } - return _safe_open( - str(path), - framework=framework_mapping[framework], - device=str(device), - ) # type: ignore - - def load_tensors(path: Path | str, /, device: Device | str = "cpu") -> dict[str, Tensor]: """Load tensors from a file saved with `torch.save` from disk. @@ -247,8 +218,7 @@ def load_from_safetensors(path: Path | str, device: Device | str = "cpu") -> dic Returns: The loaded tensors. """ - with safe_open(path=path, framework="pytorch", device=device) as tensors: # type: ignore - return {key: tensors.get_tensor(key) for key in tensors.keys()} # type: ignore + return _load_file(path, str(device)) def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None: