From ed3621362ff0b9087fd644728be1258fc3472e72 Mon Sep 17 00:00:00 2001 From: limiteinductive Date: Fri, 19 Jan 2024 16:37:01 +0100 Subject: [PATCH] Add load_tensors utils in fluxion --- .../convert_diffusers_ip_adapter.py | 4 +-- scripts/conversion/convert_dinov2.py | 4 +-- .../convert_informative_drawings.py | 3 +- .../conversion/convert_segment_anything.py | 4 +-- src/refiners/fluxion/utils.py | 22 +++++++++++++- tests/e2e/test_diffusion.py | 6 ++-- tests/fluxion/test_utils.py | 29 +++++++++++++++++++ tests/foundationals/clip/test_concepts.py | 6 ++-- .../segment_anything/test_sam.py | 4 +-- 9 files changed, 66 insertions(+), 16 deletions(-) diff --git a/scripts/conversion/convert_diffusers_ip_adapter.py b/scripts/conversion/convert_diffusers_ip_adapter.py index 8282db1..98d35e4 100644 --- a/scripts/conversion/convert_diffusers_ip_adapter.py +++ b/scripts/conversion/convert_diffusers_ip_adapter.py @@ -4,7 +4,7 @@ from typing import Any import torch -from refiners.fluxion.utils import save_to_safetensors +from refiners.fluxion.utils import load_tensors, save_to_safetensors from refiners.foundationals.latent_diffusion import SD1IPAdapter, SD1UNet, SDXLIPAdapter, SDXLUNet # Running: @@ -66,7 +66,7 @@ def main() -> None: if args.output_path is None: args.output_path = f"{Path(args.source_path).stem}.safetensors" - weights: dict[str, Any] = torch.load(f=args.source_path, map_location="cpu") # type: ignore + weights: dict[str, Any] = load_tensors(args.source_path, device="cpu") assert isinstance(weights, dict) assert sorted(weights.keys()) == ["image_proj", "ip_adapter"] diff --git a/scripts/conversion/convert_dinov2.py b/scripts/conversion/convert_dinov2.py index 0d71d47..3d51ce1 100644 --- a/scripts/conversion/convert_dinov2.py +++ b/scripts/conversion/convert_dinov2.py @@ -3,7 +3,7 @@ from pathlib import Path import torch -from refiners.fluxion.utils import save_to_safetensors +from refiners.fluxion.utils import load_tensors, save_to_safetensors def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None: @@ -148,7 +148,7 @@ def main() -> None: parser.add_argument("--half", action="store_true", dest="half") args = parser.parse_args() - weights = torch.load(args.source_path) # type: ignore + weights = load_tensors(args.source_path) convert_dinov2_facebook(weights) if args.half: weights = {key: value.half() for key, value in weights.items()} diff --git a/scripts/conversion/convert_informative_drawings.py b/scripts/conversion/convert_informative_drawings.py index cb20f36..d163472 100644 --- a/scripts/conversion/convert_informative_drawings.py +++ b/scripts/conversion/convert_informative_drawings.py @@ -5,6 +5,7 @@ import torch from torch import nn from refiners.fluxion.model_converter import ModelConverter +from refiners.fluxion.utils import load_tensors from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings try: @@ -27,7 +28,7 @@ class Args(argparse.Namespace): def setup_converter(args: Args) -> ModelConverter: source = Generator(3, 1, 3) - source.load_state_dict(state_dict=torch.load(f=args.source_path, map_location="cpu")) # type: ignore + source.load_state_dict(state_dict=load_tensors(args.source_path)) source.eval() target = InformativeDrawings() x = torch.randn(1, 3, 512, 512) diff --git a/scripts/conversion/convert_segment_anything.py b/scripts/conversion/convert_segment_anything.py index 14ba2ef..c41c2b8 100644 --- a/scripts/conversion/convert_segment_anything.py +++ b/scripts/conversion/convert_segment_anything.py @@ -10,7 +10,7 @@ from torch import Tensor import refiners.fluxion.layers as fl from refiners.fluxion.model_converter import ModelConverter -from refiners.fluxion.utils import manual_seed, save_to_safetensors +from refiners.fluxion.utils import load_tensors, manual_seed, save_to_safetensors from refiners.foundationals.segment_anything.image_encoder import SAMViTH from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder @@ -245,7 +245,7 @@ def main() -> None: args = parser.parse_args(namespace=Args()) sam_h = build_sam_vit_h() # type: ignore - sam_h.load_state_dict(state_dict=torch.load(f=args.source_path)) # type: ignore + sam_h.load_state_dict(state_dict=load_tensors(args.source_path)) vit_state_dict = convert_vit(vit=sam_h.image_encoder) mask_decoder_state_dict = convert_mask_decoder(mask_decoder=sam_h.mask_decoder) diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index dedf8e1..80dd160 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -1,5 +1,6 @@ +import warnings from pathlib import Path -from typing import Any, Iterable, Literal, TypeVar +from typing import Any, Iterable, Literal, TypeVar, cast import torch from jaxtyping import Float @@ -173,6 +174,25 @@ def safe_open( return _safe_open(str(path), framework=framework_mapping[framework], device=str(device)) # type: ignore +def load_tensors(path: Path | str, /, device: Device | str = "cpu") -> dict[str, Tensor]: + """ + Load tensors from a file saved with `torch.save` from disk using the `weights_only` mode + for additional safety (see `torch.load` for more details). Still, *only load data you trust* and + favor using `load_from_safetensors`. + """ + # see https://github.com/pytorch/pytorch/issues/97207#issuecomment-1494781560 + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated") + tensors = torch.load(path, map_location=device, weights_only=True) # type: ignore + + assert isinstance(tensors, dict) and all( + isinstance(key, str) and isinstance(value, Tensor) + for key, value in tensors.items() # type: ignore + ), "Invalid tensor file, expected a dict[str, Tensor]" + + return cast(dict[str, Tensor], tensors) + + def load_from_safetensors(path: Path | str, device: Device | str = "cpu") -> dict[str, Tensor]: with safe_open(path=path, framework="pytorch", device=device) as tensors: # type: ignore return {key: tensors.get_tensor(key) for key in tensors.keys()} # type: ignore diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index bb3ec4a..bdfb048 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -6,7 +6,7 @@ import pytest import torch from PIL import Image -from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed, no_grad +from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad from refiners.foundationals.clip.concepts import ConceptExtender from refiners.foundationals.latent_diffusion import ( SD1ControlnetAdapter, @@ -199,7 +199,7 @@ def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Im warn(f"could not find weights at {weights_path}, skipping") pytest.skip(allow_module_level=True) - tensors = torch.load(weights_path) # type: ignore + tensors = load_tensors(weights_path) return expected_image, tensors @@ -282,7 +282,7 @@ def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image. @pytest.fixture def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor: - return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")[""] # type: ignore + return load_tensors(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")[""] @pytest.fixture(scope="module") diff --git a/tests/fluxion/test_utils.py b/tests/fluxion/test_utils.py index 6165eb7..c3f6663 100644 --- a/tests/fluxion/test_utils.py +++ b/tests/fluxion/test_utils.py @@ -1,4 +1,6 @@ +import pickle from dataclasses import dataclass +from pathlib import Path from warnings import warn import pytest @@ -7,9 +9,11 @@ from PIL import Image from torch import device as Device, dtype as DType from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore +from refiners.fluxion import layers as fl from refiners.fluxion.utils import ( gaussian_blur, image_to_tensor, + load_tensors, manual_seed, no_grad, summarize_tensor, @@ -95,3 +99,28 @@ def test_no_grad() -> None: w = x + 1 assert w.requires_grad + + +def test_load_tensors_valid_pickle(tmp_path: Path) -> None: + pickle_path = tmp_path / "valid.pickle" + + tensors = {"easy-as.weight": torch.tensor([1.0, 2.0, 3.0])} + torch.save(tensors, pickle_path) # type: ignore + loaded_tensor = load_tensors(pickle_path) + assert torch.equal(loaded_tensor["easy-as.weight"], tensors["easy-as.weight"]) + + tensors = {"easy-as.weight": torch.tensor([1, 2, 3]), "hello": "world"} + torch.save(tensors, pickle_path) # type: ignore + + with pytest.raises(AssertionError): + loaded_tensor = load_tensors(pickle_path) + + +def test_load_tensors_invalid_pickle(tmp_path: Path) -> None: + invalid_pickle_path = tmp_path / "invalid.pickle" + model = fl.Chain(fl.Linear(1, 1)) + torch.save(model, invalid_pickle_path) # type: ignore + with pytest.raises( + pickle.UnpicklingError, + ): + load_tensors(invalid_pickle_path) diff --git a/tests/foundationals/clip/test_concepts.py b/tests/foundationals/clip/test_concepts.py index 9c4ed74..9ab1e7e 100644 --- a/tests/foundationals/clip/test_concepts.py +++ b/tests/foundationals/clip/test_concepts.py @@ -7,7 +7,7 @@ import transformers # type: ignore from diffusers import StableDiffusionPipeline # type: ignore import refiners.fluxion.layers as fl -from refiners.fluxion.utils import load_from_safetensors, no_grad +from refiners.fluxion.utils import load_from_safetensors, load_tensors, no_grad from refiners.foundationals.clip.concepts import ConceptExtender, TokenExtender from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL from refiners.foundationals.clip.tokenizer import CLIPTokenizer @@ -76,12 +76,12 @@ def prompt(request: pytest.FixtureRequest): @pytest.fixture(scope="module") def gta5_artwork_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor: - return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")[""] # type: ignore + return load_tensors(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")[""] @pytest.fixture(scope="module") def cat_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor: - return torch.load(test_textual_inversion_path / "cat-toy" / "learned_embeds.bin")[""] # type: ignore + return load_tensors(test_textual_inversion_path / "cat-toy" / "learned_embeds.bin")[""] def test_tokenizer_with_special_character(): diff --git a/tests/foundationals/segment_anything/test_sam.py b/tests/foundationals/segment_anything/test_sam.py index 1e56685..8b21c15 100644 --- a/tests/foundationals/segment_anything/test_sam.py +++ b/tests/foundationals/segment_anything/test_sam.py @@ -18,7 +18,7 @@ from torch import Tensor from refiners.fluxion import manual_seed from refiners.fluxion.model_converter import ModelConverter -from refiners.fluxion.utils import image_to_tensor, no_grad +from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention from refiners.foundationals.segment_anything.model import SegmentAnythingH from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer @@ -69,7 +69,7 @@ def facebook_sam_h(facebook_sam_h_weights: Path, test_device: torch.device) -> F from segment_anything import build_sam_vit_h # type: ignore sam_h = cast(FacebookSAM, build_sam_vit_h()) - sam_h.load_state_dict(state_dict=torch.load(f=facebook_sam_h_weights)) # type: ignore + sam_h.load_state_dict(state_dict=load_tensors(facebook_sam_h_weights)) return sam_h.to(device=test_device)