mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
simplify implementation of load_from_safetensors
This commit is contained in:
parent
2c0174f50e
commit
31b5f80496
|
@ -6,7 +6,7 @@ authors = [{ name = "The Finegrain Team", email = "bonjour@lagon.tech" }]
|
||||||
license = { text = "MIT License" }
|
license = { text = "MIT License" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"torch>=2.1.1",
|
"torch>=2.1.1",
|
||||||
"safetensors>=0.4.0",
|
"safetensors>=0.4.5",
|
||||||
"pillow>=10.4.0",
|
"pillow>=10.4.0",
|
||||||
"jaxtyping>=0.2.23",
|
"jaxtyping>=0.2.23",
|
||||||
"packaging>=23.2",
|
"packaging>=23.2",
|
||||||
|
|
|
@ -338,7 +338,7 @@ rpds-py==0.19.1
|
||||||
# via referencing
|
# via referencing
|
||||||
s3transfer==0.10.2
|
s3transfer==0.10.2
|
||||||
# via boto3
|
# via boto3
|
||||||
safetensors==0.4.3
|
safetensors==0.4.5
|
||||||
# via diffusers
|
# via diffusers
|
||||||
# via refiners
|
# via refiners
|
||||||
# via timm
|
# via timm
|
||||||
|
@ -347,6 +347,8 @@ segment-anything-hq==0.3
|
||||||
# via refiners
|
# via refiners
|
||||||
segment-anything-py==1.0.1
|
segment-anything-py==1.0.1
|
||||||
# via refiners
|
# via refiners
|
||||||
|
sentencepiece==0.2.0
|
||||||
|
# via refiners
|
||||||
sentry-sdk==2.12.0
|
sentry-sdk==2.12.0
|
||||||
# via wandb
|
# via wandb
|
||||||
setproctitle==1.3.3
|
setproctitle==1.3.3
|
||||||
|
|
|
@ -1,13 +1,12 @@
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Iterable, Literal, TypeVar, cast
|
from typing import Any, Iterable, TypeVar, cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from jaxtyping import Float
|
from jaxtyping import Float
|
||||||
from numpy import array, float32
|
from numpy import array, float32
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from safetensors import safe_open as _safe_open # type: ignore
|
from safetensors.torch import load_file as _load_file, save_file as _save_file # type: ignore
|
||||||
from safetensors.torch import save_file as _save_file # type: ignore
|
|
||||||
from torch import Tensor, device as Device, dtype as DType
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad # type: ignore
|
from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad # type: ignore
|
||||||
|
|
||||||
|
@ -186,34 +185,6 @@ def tensor_to_image(tensor: Tensor) -> Image.Image:
|
||||||
return Image.fromarray((tensor.cpu().numpy() * 255).astype("uint8")) # type: ignore[reportUnknownType]
|
return Image.fromarray((tensor.cpu().numpy() * 255).astype("uint8")) # type: ignore[reportUnknownType]
|
||||||
|
|
||||||
|
|
||||||
def safe_open(
|
|
||||||
path: Path | str,
|
|
||||||
framework: Literal["pytorch", "tensorflow", "flax", "numpy"],
|
|
||||||
device: Device | str = "cpu",
|
|
||||||
) -> dict[str, Tensor]:
|
|
||||||
"""Open a SafeTensor file from disk.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: The path to the file.
|
|
||||||
framework: The framework used to save the file.
|
|
||||||
device: The device to use for the tensors.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The loaded tensors.
|
|
||||||
"""
|
|
||||||
framework_mapping = {
|
|
||||||
"pytorch": "pt",
|
|
||||||
"tensorflow": "tf",
|
|
||||||
"flax": "flax",
|
|
||||||
"numpy": "numpy",
|
|
||||||
}
|
|
||||||
return _safe_open(
|
|
||||||
str(path),
|
|
||||||
framework=framework_mapping[framework],
|
|
||||||
device=str(device),
|
|
||||||
) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
def load_tensors(path: Path | str, /, device: Device | str = "cpu") -> dict[str, Tensor]:
|
def load_tensors(path: Path | str, /, device: Device | str = "cpu") -> dict[str, Tensor]:
|
||||||
"""Load tensors from a file saved with `torch.save` from disk.
|
"""Load tensors from a file saved with `torch.save` from disk.
|
||||||
|
|
||||||
|
@ -247,8 +218,7 @@ def load_from_safetensors(path: Path | str, device: Device | str = "cpu") -> dic
|
||||||
Returns:
|
Returns:
|
||||||
The loaded tensors.
|
The loaded tensors.
|
||||||
"""
|
"""
|
||||||
with safe_open(path=path, framework="pytorch", device=device) as tensors: # type: ignore
|
return _load_file(path, str(device))
|
||||||
return {key: tensors.get_tensor(key) for key in tensors.keys()} # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None:
|
def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None:
|
||||||
|
|
Loading…
Reference in a new issue