Add load_tensors utils in fluxion

limiteinductive 2024-01-19 16:37:01 +01:00 committed by Benjamin Trom
parent 91aea9b7ff
commit ed3621362f
9 changed files with 66 additions and 16 deletions


@@ -4,7 +4,7 @@ from typing import Any
 import torch
-from refiners.fluxion.utils import save_to_safetensors
+from refiners.fluxion.utils import load_tensors, save_to_safetensors
 from refiners.foundationals.latent_diffusion import SD1IPAdapter, SD1UNet, SDXLIPAdapter, SDXLUNet
 # Running:
@@ -66,7 +66,7 @@ def main() -> None:
     if args.output_path is None:
         args.output_path = f"{Path(args.source_path).stem}.safetensors"
-    weights: dict[str, Any] = torch.load(f=args.source_path, map_location="cpu")  # type: ignore
+    weights: dict[str, Any] = load_tensors(args.source_path, device="cpu")
     assert isinstance(weights, dict)
     assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]


@@ -3,7 +3,7 @@ from pathlib import Path
 import torch
-from refiners.fluxion.utils import save_to_safetensors
+from refiners.fluxion.utils import load_tensors, save_to_safetensors
 def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:
@@ -148,7 +148,7 @@ def main() -> None:
     parser.add_argument("--half", action="store_true", dest="half")
     args = parser.parse_args()
-    weights = torch.load(args.source_path)  # type: ignore
+    weights = load_tensors(args.source_path)
     convert_dinov2_facebook(weights)
     if args.half:
         weights = {key: value.half() for key, value in weights.items()}


@@ -5,6 +5,7 @@ import torch
 from torch import nn
 from refiners.fluxion.model_converter import ModelConverter
+from refiners.fluxion.utils import load_tensors
 from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
 try:
@@ -27,7 +28,7 @@ class Args(argparse.Namespace):
 def setup_converter(args: Args) -> ModelConverter:
     source = Generator(3, 1, 3)
-    source.load_state_dict(state_dict=torch.load(f=args.source_path, map_location="cpu"))  # type: ignore
+    source.load_state_dict(state_dict=load_tensors(args.source_path))
     source.eval()
     target = InformativeDrawings()
     x = torch.randn(1, 3, 512, 512)
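
Because `load_tensors` returns a plain `dict[str, Tensor]`, it can be passed straight to `load_state_dict` wherever `torch.load` was used before, as in the hunk above. A minimal sketch of the pattern, assuming a hypothetical toy module and file path:

import torch
from torch import nn

from refiners.fluxion.utils import load_tensors

toy = nn.Linear(in_features=3, out_features=3)  # hypothetical stand-in for Generator
torch.save(toy.state_dict(), "toy.pt")  # hypothetical path
toy.load_state_dict(state_dict=load_tensors("toy.pt"))  # unpickled with weights_only=True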


@@ -10,7 +10,7 @@ from torch import Tensor
 import refiners.fluxion.layers as fl
 from refiners.fluxion.model_converter import ModelConverter
-from refiners.fluxion.utils import manual_seed, save_to_safetensors
+from refiners.fluxion.utils import load_tensors, manual_seed, save_to_safetensors
 from refiners.foundationals.segment_anything.image_encoder import SAMViTH
 from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
 from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
@@ -245,7 +245,7 @@ def main() -> None:
     args = parser.parse_args(namespace=Args())
     sam_h = build_sam_vit_h()  # type: ignore
-    sam_h.load_state_dict(state_dict=torch.load(f=args.source_path))  # type: ignore
+    sam_h.load_state_dict(state_dict=load_tensors(args.source_path))
     vit_state_dict = convert_vit(vit=sam_h.image_encoder)
     mask_decoder_state_dict = convert_mask_decoder(mask_decoder=sam_h.mask_decoder)


@@ -1,5 +1,6 @@
+import warnings
 from pathlib import Path
-from typing import Any, Iterable, Literal, TypeVar
+from typing import Any, Iterable, Literal, TypeVar, cast
 import torch
 from jaxtyping import Float
@@ -173,6 +174,25 @@ def safe_open(
     return _safe_open(str(path), framework=framework_mapping[framework], device=str(device))  # type: ignore
+def load_tensors(path: Path | str, /, device: Device | str = "cpu") -> dict[str, Tensor]:
+    """
+    Load tensors from a file saved with `torch.save`, using the `weights_only` mode of
+    `torch.load` for additional safety (see `torch.load` for more details). Still, *only
+    load data you trust* and favor using `load_from_safetensors` instead.
+    """
+    # see https://github.com/pytorch/pytorch/issues/97207#issuecomment-1494781560
+    with warnings.catch_warnings():
+        warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
+        tensors = torch.load(path, map_location=device, weights_only=True)  # type: ignore
+    assert isinstance(tensors, dict) and all(
+        isinstance(key, str) and isinstance(value, Tensor)
+        for key, value in tensors.items()  # type: ignore
+    ), "Invalid tensor file, expected a dict[str, Tensor]"
+    return cast(dict[str, Tensor], tensors)
 def load_from_safetensors(path: Path | str, device: Device | str = "cpu") -> dict[str, Tensor]:
     with safe_open(path=path, framework="pytorch", device=device) as tensors:  # type: ignore
         return {key: tensors.get_tensor(key) for key in tensors.keys()}  # type: ignore
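
For illustration, a minimal usage sketch of the two loaders side by side (both file paths are hypothetical): `load_tensors` covers legacy `torch.save` pickles, while `load_from_safetensors` stays the preferred route for safetensors files.

from refiners.fluxion.utils import load_from_safetensors, load_tensors

# Legacy pickle checkpoint, unpickled via torch.load(..., weights_only=True).
weights = load_tensors("checkpoint.pt", device="cpu")  # hypothetical path
# Preferred format: parsed without invoking the pickle machinery at all.
safe_weights = load_from_safetensors("checkpoint.safetensors", device="cpu")  # hypothetical path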


@@ -6,7 +6,7 @@ import pytest
 import torch
 from PIL import Image
-from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed, no_grad
+from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad
 from refiners.foundationals.clip.concepts import ConceptExtender
 from refiners.foundationals.latent_diffusion import (
     SD1ControlnetAdapter,
@@ -199,7 +199,7 @@ def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Im
         warn(f"could not find weights at {weights_path}, skipping")
         pytest.skip(allow_module_level=True)
-    tensors = torch.load(weights_path)  # type: ignore
+    tensors = load_tensors(weights_path)
     return expected_image, tensors
@@ -282,7 +282,7 @@ def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image.
 @pytest.fixture
 def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
-    return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"]  # type: ignore
+    return load_tensors(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"]
 @pytest.fixture(scope="module")


@@ -1,4 +1,6 @@
+import pickle
 from dataclasses import dataclass
+from pathlib import Path
 from warnings import warn
 import pytest
@@ -7,9 +9,11 @@ from PIL import Image
 from torch import device as Device, dtype as DType
 from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur  # type: ignore
+from refiners.fluxion import layers as fl
 from refiners.fluxion.utils import (
     gaussian_blur,
     image_to_tensor,
+    load_tensors,
     manual_seed,
     no_grad,
     summarize_tensor,
@@ -95,3 +99,28 @@ def test_no_grad() -> None:
         w = x + 1
         assert w.requires_grad
+def test_load_tensors_valid_pickle(tmp_path: Path) -> None:
+    pickle_path = tmp_path / "valid.pickle"
+    tensors = {"easy-as.weight": torch.tensor([1.0, 2.0, 3.0])}
+    torch.save(tensors, pickle_path)  # type: ignore
+    loaded_tensor = load_tensors(pickle_path)
+    assert torch.equal(loaded_tensor["easy-as.weight"], tensors["easy-as.weight"])
+
+    tensors = {"easy-as.weight": torch.tensor([1, 2, 3]), "hello": "world"}
+    torch.save(tensors, pickle_path)  # type: ignore
+    with pytest.raises(AssertionError):
+        loaded_tensor = load_tensors(pickle_path)
+
+def test_load_tensors_invalid_pickle(tmp_path: Path) -> None:
+    invalid_pickle_path = tmp_path / "invalid.pickle"
+    model = fl.Chain(fl.Linear(1, 1))
+    torch.save(model, invalid_pickle_path)  # type: ignore
+    with pytest.raises(
+        pickle.UnpicklingError,
+    ):
+        load_tensors(invalid_pickle_path)
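
These tests pin down the two failure modes: `weights_only=True` makes `torch.load` refuse arbitrary pickled objects (such as an `fl.Chain`) with `pickle.UnpicklingError`, and the assertion in `load_tensors` rejects dicts whose values are not all tensors. A caller wanting a soft failure could wrap both; a minimal sketch with a hypothetical `try_load_tensors` helper:

import pickle

from torch import Tensor

from refiners.fluxion.utils import load_tensors

def try_load_tensors(path: str) -> dict[str, Tensor] | None:
    # Hypothetical helper: return None instead of raising on untrusted or malformed files.
    try:
        return load_tensors(path)
    except pickle.UnpicklingError:  # disallowed global hit by the weights_only unpickler
        return None
    except AssertionError:  # loaded, but not a dict[str, Tensor]
        return None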


@@ -7,7 +7,7 @@ import transformers  # type: ignore
 from diffusers import StableDiffusionPipeline  # type: ignore
 import refiners.fluxion.layers as fl
-from refiners.fluxion.utils import load_from_safetensors, no_grad
+from refiners.fluxion.utils import load_from_safetensors, load_tensors, no_grad
 from refiners.foundationals.clip.concepts import ConceptExtender, TokenExtender
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 from refiners.foundationals.clip.tokenizer import CLIPTokenizer
@@ -76,12 +76,12 @@ def prompt(request: pytest.FixtureRequest):
 @pytest.fixture(scope="module")
 def gta5_artwork_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
-    return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"]  # type: ignore
+    return load_tensors(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"]
 @pytest.fixture(scope="module")
 def cat_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
-    return torch.load(test_textual_inversion_path / "cat-toy" / "learned_embeds.bin")["<cat-toy>"]  # type: ignore
+    return load_tensors(test_textual_inversion_path / "cat-toy" / "learned_embeds.bin")["<cat-toy>"]
 def test_tokenizer_with_special_character():


@@ -18,7 +18,7 @@ from torch import Tensor
 from refiners.fluxion import manual_seed
 from refiners.fluxion.model_converter import ModelConverter
-from refiners.fluxion.utils import image_to_tensor, no_grad
+from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
 from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention
 from refiners.foundationals.segment_anything.model import SegmentAnythingH
 from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer
@@ -69,7 +69,7 @@ def facebook_sam_h(facebook_sam_h_weights: Path, test_device: torch.device) -> F
     from segment_anything import build_sam_vit_h  # type: ignore
     sam_h = cast(FacebookSAM, build_sam_vit_h())
-    sam_h.load_state_dict(state_dict=torch.load(f=facebook_sam_h_weights))  # type: ignore
+    sam_h.load_state_dict(state_dict=load_tensors(facebook_sam_h_weights))
     return sam_h.to(device=test_device)