Add load_tensors utils in fluxion

limiteinductive authored on 2024-01-19 16:37:01 +01:00, committed by Benjamin Trom
parent 91aea9b7ff
commit ed3621362f
9 changed files with 66 additions and 16 deletions

View file

@@ -4,7 +4,7 @@ from typing import Any
 import torch
-from refiners.fluxion.utils import save_to_safetensors
+from refiners.fluxion.utils import load_tensors, save_to_safetensors
 from refiners.foundationals.latent_diffusion import SD1IPAdapter, SD1UNet, SDXLIPAdapter, SDXLUNet
 # Running:
@@ -66,7 +66,7 @@ def main() -> None:
     if args.output_path is None:
         args.output_path = f"{Path(args.source_path).stem}.safetensors"
-    weights: dict[str, Any] = torch.load(f=args.source_path, map_location="cpu") # type: ignore
+    weights: dict[str, Any] = load_tensors(args.source_path, device="cpu")
     assert isinstance(weights, dict)
     assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]

View file

@@ -3,7 +3,7 @@ from pathlib import Path
 import torch
-from refiners.fluxion.utils import save_to_safetensors
+from refiners.fluxion.utils import load_tensors, save_to_safetensors
 def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:
@@ -148,7 +148,7 @@ def main() -> None:
     parser.add_argument("--half", action="store_true", dest="half")
     args = parser.parse_args()
-    weights = torch.load(args.source_path) # type: ignore
+    weights = load_tensors(args.source_path)
     convert_dinov2_facebook(weights)
     if args.half:
         weights = {key: value.half() for key, value in weights.items()}

View file

@@ -5,6 +5,7 @@ import torch
 from torch import nn
 from refiners.fluxion.model_converter import ModelConverter
+from refiners.fluxion.utils import load_tensors
 from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
try:
@@ -27,7 +28,7 @@ class Args(argparse.Namespace):
 def setup_converter(args: Args) -> ModelConverter:
     source = Generator(3, 1, 3)
-    source.load_state_dict(state_dict=torch.load(f=args.source_path, map_location="cpu")) # type: ignore
+    source.load_state_dict(state_dict=load_tensors(args.source_path))
     source.eval()
     target = InformativeDrawings()
     x = torch.randn(1, 3, 512, 512)

View file

@@ -10,7 +10,7 @@ from torch import Tensor
 import refiners.fluxion.layers as fl
 from refiners.fluxion.model_converter import ModelConverter
-from refiners.fluxion.utils import manual_seed, save_to_safetensors
+from refiners.fluxion.utils import load_tensors, manual_seed, save_to_safetensors
 from refiners.foundationals.segment_anything.image_encoder import SAMViTH
 from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
 from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
@@ -245,7 +245,7 @@ def main() -> None:
     args = parser.parse_args(namespace=Args())
     sam_h = build_sam_vit_h() # type: ignore
-    sam_h.load_state_dict(state_dict=torch.load(f=args.source_path)) # type: ignore
+    sam_h.load_state_dict(state_dict=load_tensors(args.source_path))
     vit_state_dict = convert_vit(vit=sam_h.image_encoder)
     mask_decoder_state_dict = convert_mask_decoder(mask_decoder=sam_h.mask_decoder)

View file

@@ -1,5 +1,6 @@
+import warnings
 from pathlib import Path
-from typing import Any, Iterable, Literal, TypeVar
+from typing import Any, Iterable, Literal, TypeVar, cast
 import torch
 from jaxtyping import Float
@@ -173,6 +174,25 @@ def safe_open(
     return _safe_open(str(path), framework=framework_mapping[framework], device=str(device)) # type: ignore
 
 
+def load_tensors(path: Path | str, /, device: Device | str = "cpu") -> dict[str, Tensor]:
+    """
+    Load tensors from a file saved with `torch.save` from disk using the `weights_only` mode
+    for additional safety (see `torch.load` for more details). Still, *only load data you trust* and
+    favor using `load_from_safetensors`.
+    """
+    # see https://github.com/pytorch/pytorch/issues/97207#issuecomment-1494781560
+    with warnings.catch_warnings():
+        warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
+        tensors = torch.load(path, map_location=device, weights_only=True) # type: ignore
+
+    assert isinstance(tensors, dict) and all(
+        isinstance(key, str) and isinstance(value, Tensor)
+        for key, value in tensors.items() # type: ignore
+    ), "Invalid tensor file, expected a dict[str, Tensor]"
+
+    return cast(dict[str, Tensor], tensors)
+
+
 def load_from_safetensors(path: Path | str, device: Device | str = "cpu") -> dict[str, Tensor]:
     with safe_open(path=path, framework="pytorch", device=device) as tensors: # type: ignore
         return {key: tensors.get_tensor(key) for key in tensors.keys()} # type: ignore
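A minimal usage sketch of the new helper (not part of the diff), assuming a hypothetical `pytorch_model.bin` checkpoint produced with `torch.save`; `save_to_safetensors` is the existing companion utility in `refiners.fluxion.utils`:

    from refiners.fluxion.utils import load_tensors, save_to_safetensors

    # load_tensors wraps torch.load(..., weights_only=True): it refuses to unpickle
    # arbitrary Python objects and asserts the result is a dict[str, Tensor].
    weights = load_tensors("pytorch_model.bin", device="cpu")  # hypothetical checkpoint path

    # Re-saving as safetensors remains the preferred, safer storage format.
    save_to_safetensors("pytorch_model.safetensors", weights)

Because of `weights_only=True`, feeding it a pickle that contains arbitrary objects (for example a full `nn.Module`) raises `pickle.UnpicklingError`, which is exactly what the new tests below exercise.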

View file

@@ -6,7 +6,7 @@ import pytest
 import torch
 from PIL import Image
-from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed, no_grad
+from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad
 from refiners.foundationals.clip.concepts import ConceptExtender
 from refiners.foundationals.latent_diffusion import (
     SD1ControlnetAdapter,
@@ -199,7 +199,7 @@ def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Im
         warn(f"could not find weights at {weights_path}, skipping")
         pytest.skip(allow_module_level=True)
-    tensors = torch.load(weights_path) # type: ignore
+    tensors = load_tensors(weights_path)
     return expected_image, tensors
@@ -282,7 +282,7 @@ def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image.
 @pytest.fixture
 def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
-    return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"] # type: ignore
+    return load_tensors(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"]
 @pytest.fixture(scope="module")

View file

@@ -1,4 +1,6 @@
+import pickle
 from dataclasses import dataclass
 from pathlib import Path
 from warnings import warn
 import pytest
@@ -7,9 +9,11 @@ from PIL import Image
 from torch import device as Device, dtype as DType
 from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore
+from refiners.fluxion import layers as fl
 from refiners.fluxion.utils import (
     gaussian_blur,
     image_to_tensor,
+    load_tensors,
     manual_seed,
     no_grad,
     summarize_tensor,
@@ -95,3 +99,28 @@ def test_no_grad() -> None:
     w = x + 1
     assert w.requires_grad
 
 
+def test_load_tensors_valid_pickle(tmp_path: Path) -> None:
+    pickle_path = tmp_path / "valid.pickle"
+    tensors = {"easy-as.weight": torch.tensor([1.0, 2.0, 3.0])}
+    torch.save(tensors, pickle_path) # type: ignore
+    loaded_tensor = load_tensors(pickle_path)
+    assert torch.equal(loaded_tensor["easy-as.weight"], tensors["easy-as.weight"])
+
+    tensors = {"easy-as.weight": torch.tensor([1, 2, 3]), "hello": "world"}
+    torch.save(tensors, pickle_path) # type: ignore
+    with pytest.raises(AssertionError):
+        loaded_tensor = load_tensors(pickle_path)
+
+
+def test_load_tensors_invalid_pickle(tmp_path: Path) -> None:
+    invalid_pickle_path = tmp_path / "invalid.pickle"
+    model = fl.Chain(fl.Linear(1, 1))
+    torch.save(model, invalid_pickle_path) # type: ignore
+    with pytest.raises(
+        pickle.UnpicklingError,
+    ):
+        load_tensors(invalid_pickle_path)

View file

@@ -7,7 +7,7 @@ import transformers # type: ignore
 from diffusers import StableDiffusionPipeline # type: ignore
 import refiners.fluxion.layers as fl
-from refiners.fluxion.utils import load_from_safetensors, no_grad
+from refiners.fluxion.utils import load_from_safetensors, load_tensors, no_grad
 from refiners.foundationals.clip.concepts import ConceptExtender, TokenExtender
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 from refiners.foundationals.clip.tokenizer import CLIPTokenizer
@@ -76,12 +76,12 @@ def prompt(request: pytest.FixtureRequest):
 @pytest.fixture(scope="module")
 def gta5_artwork_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
-    return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"] # type: ignore
+    return load_tensors(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"]
 @pytest.fixture(scope="module")
 def cat_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
-    return torch.load(test_textual_inversion_path / "cat-toy" / "learned_embeds.bin")["<cat-toy>"] # type: ignore
+    return load_tensors(test_textual_inversion_path / "cat-toy" / "learned_embeds.bin")["<cat-toy>"]
 def test_tokenizer_with_special_character():

View file

@@ -18,7 +18,7 @@ from torch import Tensor
 from refiners.fluxion import manual_seed
 from refiners.fluxion.model_converter import ModelConverter
-from refiners.fluxion.utils import image_to_tensor, no_grad
+from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
 from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention
 from refiners.foundationals.segment_anything.model import SegmentAnythingH
 from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer
@@ -69,7 +69,7 @@ def facebook_sam_h(facebook_sam_h_weights: Path, test_device: torch.device) -> F
     from segment_anything import build_sam_vit_h # type: ignore
     sam_h = cast(FacebookSAM, build_sam_vit_h())
-    sam_h.load_state_dict(state_dict=torch.load(f=facebook_sam_h_weights)) # type: ignore
+    sam_h.load_state_dict(state_dict=load_tensors(facebook_sam_h_weights))
     return sam_h.to(device=test_device)