mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
Add load_tensors utils in fluxion
This commit is contained in:
parent
91aea9b7ff
commit
ed3621362f
|
@ -4,7 +4,7 @@ from typing import Any
|
|||
|
||||
import torch
|
||||
|
||||
from refiners.fluxion.utils import save_to_safetensors
|
||||
from refiners.fluxion.utils import load_tensors, save_to_safetensors
|
||||
from refiners.foundationals.latent_diffusion import SD1IPAdapter, SD1UNet, SDXLIPAdapter, SDXLUNet
|
||||
|
||||
# Running:
|
||||
|
@ -66,7 +66,7 @@ def main() -> None:
|
|||
if args.output_path is None:
|
||||
args.output_path = f"{Path(args.source_path).stem}.safetensors"
|
||||
|
||||
weights: dict[str, Any] = torch.load(f=args.source_path, map_location="cpu") # type: ignore
|
||||
weights: dict[str, Any] = load_tensors(args.source_path, device="cpu")
|
||||
assert isinstance(weights, dict)
|
||||
assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ from pathlib import Path
|
|||
|
||||
import torch
|
||||
|
||||
from refiners.fluxion.utils import save_to_safetensors
|
||||
from refiners.fluxion.utils import load_tensors, save_to_safetensors
|
||||
|
||||
|
||||
def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:
|
||||
|
@ -148,7 +148,7 @@ def main() -> None:
|
|||
parser.add_argument("--half", action="store_true", dest="half")
|
||||
args = parser.parse_args()
|
||||
|
||||
weights = torch.load(args.source_path) # type: ignore
|
||||
weights = load_tensors(args.source_path)
|
||||
convert_dinov2_facebook(weights)
|
||||
if args.half:
|
||||
weights = {key: value.half() for key, value in weights.items()}
|
||||
|
|
|
@ -5,6 +5,7 @@ import torch
|
|||
from torch import nn
|
||||
|
||||
from refiners.fluxion.model_converter import ModelConverter
|
||||
from refiners.fluxion.utils import load_tensors
|
||||
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
|
||||
|
||||
try:
|
||||
|
@ -27,7 +28,7 @@ class Args(argparse.Namespace):
|
|||
|
||||
def setup_converter(args: Args) -> ModelConverter:
|
||||
source = Generator(3, 1, 3)
|
||||
source.load_state_dict(state_dict=torch.load(f=args.source_path, map_location="cpu")) # type: ignore
|
||||
source.load_state_dict(state_dict=load_tensors(args.source_path))
|
||||
source.eval()
|
||||
target = InformativeDrawings()
|
||||
x = torch.randn(1, 3, 512, 512)
|
||||
|
|
|
@ -10,7 +10,7 @@ from torch import Tensor
|
|||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.model_converter import ModelConverter
|
||||
from refiners.fluxion.utils import manual_seed, save_to_safetensors
|
||||
from refiners.fluxion.utils import load_tensors, manual_seed, save_to_safetensors
|
||||
from refiners.foundationals.segment_anything.image_encoder import SAMViTH
|
||||
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
|
||||
from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
|
||||
|
@ -245,7 +245,7 @@ def main() -> None:
|
|||
args = parser.parse_args(namespace=Args())
|
||||
|
||||
sam_h = build_sam_vit_h() # type: ignore
|
||||
sam_h.load_state_dict(state_dict=torch.load(f=args.source_path)) # type: ignore
|
||||
sam_h.load_state_dict(state_dict=load_tensors(args.source_path))
|
||||
|
||||
vit_state_dict = convert_vit(vit=sam_h.image_encoder)
|
||||
mask_decoder_state_dict = convert_mask_decoder(mask_decoder=sam_h.mask_decoder)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, Literal, TypeVar
|
||||
from typing import Any, Iterable, Literal, TypeVar, cast
|
||||
|
||||
import torch
|
||||
from jaxtyping import Float
|
||||
|
@ -173,6 +174,25 @@ def safe_open(
|
|||
return _safe_open(str(path), framework=framework_mapping[framework], device=str(device)) # type: ignore
|
||||
|
||||
|
||||
def load_tensors(path: Path | str, /, device: Device | str = "cpu") -> dict[str, Tensor]:
|
||||
"""
|
||||
Load tensors from a file saved with `torch.save` from disk using the `weights_only` mode
|
||||
for additional safety (see `torch.load` for more details). Still, *only load data you trust* and
|
||||
favor using `load_from_safetensors`.
|
||||
"""
|
||||
# see https://github.com/pytorch/pytorch/issues/97207#issuecomment-1494781560
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
|
||||
tensors = torch.load(path, map_location=device, weights_only=True) # type: ignore
|
||||
|
||||
assert isinstance(tensors, dict) and all(
|
||||
isinstance(key, str) and isinstance(value, Tensor)
|
||||
for key, value in tensors.items() # type: ignore
|
||||
), "Invalid tensor file, expected a dict[str, Tensor]"
|
||||
|
||||
return cast(dict[str, Tensor], tensors)
|
||||
|
||||
|
||||
def load_from_safetensors(path: Path | str, device: Device | str = "cpu") -> dict[str, Tensor]:
|
||||
with safe_open(path=path, framework="pytorch", device=device) as tensors: # type: ignore
|
||||
return {key: tensors.get_tensor(key) for key in tensors.keys()} # type: ignore
|
||||
|
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed, no_grad
|
||||
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad
|
||||
from refiners.foundationals.clip.concepts import ConceptExtender
|
||||
from refiners.foundationals.latent_diffusion import (
|
||||
SD1ControlnetAdapter,
|
||||
|
@ -199,7 +199,7 @@ def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Im
|
|||
warn(f"could not find weights at {weights_path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
|
||||
tensors = torch.load(weights_path) # type: ignore
|
||||
tensors = load_tensors(weights_path)
|
||||
return expected_image, tensors
|
||||
|
||||
|
||||
|
@ -282,7 +282,7 @@ def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image.
|
|||
|
||||
@pytest.fixture
|
||||
def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
|
||||
return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"] # type: ignore
|
||||
return load_tensors(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
import pickle
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from warnings import warn
|
||||
|
||||
import pytest
|
||||
|
@ -7,9 +9,11 @@ from PIL import Image
|
|||
from torch import device as Device, dtype as DType
|
||||
from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore
|
||||
|
||||
from refiners.fluxion import layers as fl
|
||||
from refiners.fluxion.utils import (
|
||||
gaussian_blur,
|
||||
image_to_tensor,
|
||||
load_tensors,
|
||||
manual_seed,
|
||||
no_grad,
|
||||
summarize_tensor,
|
||||
|
@ -95,3 +99,28 @@ def test_no_grad() -> None:
|
|||
|
||||
w = x + 1
|
||||
assert w.requires_grad
|
||||
|
||||
|
||||
def test_load_tensors_valid_pickle(tmp_path: Path) -> None:
|
||||
pickle_path = tmp_path / "valid.pickle"
|
||||
|
||||
tensors = {"easy-as.weight": torch.tensor([1.0, 2.0, 3.0])}
|
||||
torch.save(tensors, pickle_path) # type: ignore
|
||||
loaded_tensor = load_tensors(pickle_path)
|
||||
assert torch.equal(loaded_tensor["easy-as.weight"], tensors["easy-as.weight"])
|
||||
|
||||
tensors = {"easy-as.weight": torch.tensor([1, 2, 3]), "hello": "world"}
|
||||
torch.save(tensors, pickle_path) # type: ignore
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
loaded_tensor = load_tensors(pickle_path)
|
||||
|
||||
|
||||
def test_load_tensors_invalid_pickle(tmp_path: Path) -> None:
|
||||
invalid_pickle_path = tmp_path / "invalid.pickle"
|
||||
model = fl.Chain(fl.Linear(1, 1))
|
||||
torch.save(model, invalid_pickle_path) # type: ignore
|
||||
with pytest.raises(
|
||||
pickle.UnpicklingError,
|
||||
):
|
||||
load_tensors(invalid_pickle_path)
|
||||
|
|
|
@ -7,7 +7,7 @@ import transformers # type: ignore
|
|||
from diffusers import StableDiffusionPipeline # type: ignore
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.utils import load_from_safetensors, no_grad
|
||||
from refiners.fluxion.utils import load_from_safetensors, load_tensors, no_grad
|
||||
from refiners.foundationals.clip.concepts import ConceptExtender, TokenExtender
|
||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||
|
@ -76,12 +76,12 @@ def prompt(request: pytest.FixtureRequest):
|
|||
|
||||
@pytest.fixture(scope="module")
|
||||
def gta5_artwork_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
|
||||
return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"] # type: ignore
|
||||
return load_tensors(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def cat_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
|
||||
return torch.load(test_textual_inversion_path / "cat-toy" / "learned_embeds.bin")["<cat-toy>"] # type: ignore
|
||||
return load_tensors(test_textual_inversion_path / "cat-toy" / "learned_embeds.bin")["<cat-toy>"]
|
||||
|
||||
|
||||
def test_tokenizer_with_special_character():
|
||||
|
|
|
@ -18,7 +18,7 @@ from torch import Tensor
|
|||
|
||||
from refiners.fluxion import manual_seed
|
||||
from refiners.fluxion.model_converter import ModelConverter
|
||||
from refiners.fluxion.utils import image_to_tensor, no_grad
|
||||
from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
|
||||
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention
|
||||
from refiners.foundationals.segment_anything.model import SegmentAnythingH
|
||||
from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer
|
||||
|
@ -69,7 +69,7 @@ def facebook_sam_h(facebook_sam_h_weights: Path, test_device: torch.device) -> F
|
|||
from segment_anything import build_sam_vit_h # type: ignore
|
||||
|
||||
sam_h = cast(FacebookSAM, build_sam_vit_h())
|
||||
sam_h.load_state_dict(state_dict=torch.load(f=facebook_sam_h_weights)) # type: ignore
|
||||
sam_h.load_state_dict(state_dict=load_tensors(facebook_sam_h_weights))
|
||||
return sam_h.to(device=test_device)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue