mirror of https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00

Add load_tensors utils in fluxion

commit ed3621362f (parent 91aea9b7ff)
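This commit introduces `load_tensors` in `refiners.fluxion.utils` and switches the conversion scripts and test fixtures from raw `torch.load` calls to it. The helper loads `torch.save` checkpoints with `weights_only=True` for additional safety and asserts that the loaded object is a `dict[str, Tensor]`; `load_from_safetensors` remains the recommended loader when a safetensors file is available.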
@@ -4,7 +4,7 @@ from typing import Any

 import torch

-from refiners.fluxion.utils import save_to_safetensors
+from refiners.fluxion.utils import load_tensors, save_to_safetensors
 from refiners.foundationals.latent_diffusion import SD1IPAdapter, SD1UNet, SDXLIPAdapter, SDXLUNet

 # Running:
@@ -66,7 +66,7 @@ def main() -> None:
     if args.output_path is None:
         args.output_path = f"{Path(args.source_path).stem}.safetensors"

-    weights: dict[str, Any] = torch.load(f=args.source_path, map_location="cpu")  # type: ignore
+    weights: dict[str, Any] = load_tensors(args.source_path, device="cpu")
     assert isinstance(weights, dict)
     assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]

@@ -3,7 +3,7 @@ from pathlib import Path

 import torch

-from refiners.fluxion.utils import save_to_safetensors
+from refiners.fluxion.utils import load_tensors, save_to_safetensors


 def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:
@@ -148,7 +148,7 @@ def main() -> None:
     parser.add_argument("--half", action="store_true", dest="half")
     args = parser.parse_args()

-    weights = torch.load(args.source_path)  # type: ignore
+    weights = load_tensors(args.source_path)
     convert_dinov2_facebook(weights)
     if args.half:
         weights = {key: value.half() for key, value in weights.items()}
@@ -5,6 +5,7 @@ import torch
 from torch import nn

 from refiners.fluxion.model_converter import ModelConverter
+from refiners.fluxion.utils import load_tensors
 from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings

 try:
@@ -27,7 +28,7 @@ class Args(argparse.Namespace):

 def setup_converter(args: Args) -> ModelConverter:
     source = Generator(3, 1, 3)
-    source.load_state_dict(state_dict=torch.load(f=args.source_path, map_location="cpu"))  # type: ignore
+    source.load_state_dict(state_dict=load_tensors(args.source_path))
     source.eval()
     target = InformativeDrawings()
     x = torch.randn(1, 3, 512, 512)
@@ -10,7 +10,7 @@ from torch import Tensor

 import refiners.fluxion.layers as fl
 from refiners.fluxion.model_converter import ModelConverter
-from refiners.fluxion.utils import manual_seed, save_to_safetensors
+from refiners.fluxion.utils import load_tensors, manual_seed, save_to_safetensors
 from refiners.foundationals.segment_anything.image_encoder import SAMViTH
 from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
 from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
@@ -245,7 +245,7 @@ def main() -> None:
     args = parser.parse_args(namespace=Args())

     sam_h = build_sam_vit_h()  # type: ignore
-    sam_h.load_state_dict(state_dict=torch.load(f=args.source_path))  # type: ignore
+    sam_h.load_state_dict(state_dict=load_tensors(args.source_path))

     vit_state_dict = convert_vit(vit=sam_h.image_encoder)
     mask_decoder_state_dict = convert_mask_decoder(mask_decoder=sam_h.mask_decoder)
@@ -1,5 +1,6 @@
+import warnings
 from pathlib import Path
-from typing import Any, Iterable, Literal, TypeVar
+from typing import Any, Iterable, Literal, TypeVar, cast

 import torch
 from jaxtyping import Float
@@ -173,6 +174,25 @@ def safe_open(
     return _safe_open(str(path), framework=framework_mapping[framework], device=str(device))  # type: ignore


+def load_tensors(path: Path | str, /, device: Device | str = "cpu") -> dict[str, Tensor]:
+    """
+    Load tensors from a file saved with `torch.save` from disk using the `weights_only` mode
+    for additional safety (see `torch.load` for more details). Still, *only load data you trust* and
+    favor using `load_from_safetensors`.
+    """
+    # see https://github.com/pytorch/pytorch/issues/97207#issuecomment-1494781560
+    with warnings.catch_warnings():
+        warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
+        tensors = torch.load(path, map_location=device, weights_only=True)  # type: ignore
+
+    assert isinstance(tensors, dict) and all(
+        isinstance(key, str) and isinstance(value, Tensor)
+        for key, value in tensors.items()  # type: ignore
+    ), "Invalid tensor file, expected a dict[str, Tensor]"
+
+    return cast(dict[str, Tensor], tensors)
+
+
 def load_from_safetensors(path: Path | str, device: Device | str = "cpu") -> dict[str, Tensor]:
     with safe_open(path=path, framework="pytorch", device=device) as tensors:  # type: ignore
         return {key: tensors.get_tensor(key) for key in tensors.keys()}  # type: ignore
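For illustration only, a minimal usage sketch of the new helper (the `weights.pt` filename is hypothetical; `load_from_safetensors` remains the preferred loader when a `.safetensors` file is available):

    import torch
    from refiners.fluxion.utils import load_tensors

    # Hypothetical file for illustration: a plain dict[str, Tensor] saved with torch.save.
    torch.save({"linear.weight": torch.ones(2, 2)}, "weights.pt")

    # Deserializes with weights_only=True, then asserts the result is a dict[str, Tensor].
    tensors = load_tensors("weights.pt", device="cpu")
    print(tensors["linear.weight"].shape)  # torch.Size([2, 2])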
@@ -6,7 +6,7 @@ import pytest
 import torch
 from PIL import Image

-from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed, no_grad
+from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, load_tensors, manual_seed, no_grad
 from refiners.foundationals.clip.concepts import ConceptExtender
 from refiners.foundationals.latent_diffusion import (
     SD1ControlnetAdapter,
@@ -199,7 +199,7 @@ def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Im
         warn(f"could not find weights at {weights_path}, skipping")
         pytest.skip(allow_module_level=True)

-    tensors = torch.load(weights_path)  # type: ignore
+    tensors = load_tensors(weights_path)
     return expected_image, tensors
@@ -282,7 +282,7 @@ def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image.

 @pytest.fixture
 def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
-    return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"]  # type: ignore
+    return load_tensors(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"]


 @pytest.fixture(scope="module")
@@ -1,4 +1,6 @@
+import pickle
 from dataclasses import dataclass
+from pathlib import Path
 from warnings import warn

 import pytest
@@ -7,9 +9,11 @@ from PIL import Image
 from torch import device as Device, dtype as DType
 from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur  # type: ignore

+from refiners.fluxion import layers as fl
 from refiners.fluxion.utils import (
     gaussian_blur,
     image_to_tensor,
+    load_tensors,
     manual_seed,
     no_grad,
     summarize_tensor,
@@ -95,3 +99,28 @@ def test_no_grad() -> None:

     w = x + 1
     assert w.requires_grad
+
+
+def test_load_tensors_valid_pickle(tmp_path: Path) -> None:
+    pickle_path = tmp_path / "valid.pickle"
+
+    tensors = {"easy-as.weight": torch.tensor([1.0, 2.0, 3.0])}
+    torch.save(tensors, pickle_path)  # type: ignore
+    loaded_tensor = load_tensors(pickle_path)
+    assert torch.equal(loaded_tensor["easy-as.weight"], tensors["easy-as.weight"])
+
+    tensors = {"easy-as.weight": torch.tensor([1, 2, 3]), "hello": "world"}
+    torch.save(tensors, pickle_path)  # type: ignore
+
+    with pytest.raises(AssertionError):
+        loaded_tensor = load_tensors(pickle_path)
+
+
+def test_load_tensors_invalid_pickle(tmp_path: Path) -> None:
+    invalid_pickle_path = tmp_path / "invalid.pickle"
+    model = fl.Chain(fl.Linear(1, 1))
+    torch.save(model, invalid_pickle_path)  # type: ignore
+    with pytest.raises(
+        pickle.UnpicklingError,
+    ):
+        load_tensors(invalid_pickle_path)
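A note on the two tests above: with `weights_only=True`, `torch.load` deserializes through a restricted unpickler that only reconstructs tensors and other allow-listed types, so loading a pickled module such as `fl.Chain(fl.Linear(1, 1))` raises `pickle.UnpicklingError` instead of executing arbitrary constructors. The valid-pickle test instead exercises the helper's own check: a dict containing the non-tensor value `"world"` fails the `dict[str, Tensor]` assertion with an `AssertionError`.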
@@ -7,7 +7,7 @@ import transformers  # type: ignore
 from diffusers import StableDiffusionPipeline  # type: ignore

 import refiners.fluxion.layers as fl
-from refiners.fluxion.utils import load_from_safetensors, no_grad
+from refiners.fluxion.utils import load_from_safetensors, load_tensors, no_grad
 from refiners.foundationals.clip.concepts import ConceptExtender, TokenExtender
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 from refiners.foundationals.clip.tokenizer import CLIPTokenizer
@@ -76,12 +76,12 @@ def prompt(request: pytest.FixtureRequest):

 @pytest.fixture(scope="module")
 def gta5_artwork_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
-    return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"]  # type: ignore
+    return load_tensors(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"]


 @pytest.fixture(scope="module")
 def cat_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
-    return torch.load(test_textual_inversion_path / "cat-toy" / "learned_embeds.bin")["<cat-toy>"]  # type: ignore
+    return load_tensors(test_textual_inversion_path / "cat-toy" / "learned_embeds.bin")["<cat-toy>"]


 def test_tokenizer_with_special_character():
@@ -18,7 +18,7 @@ from torch import Tensor

 from refiners.fluxion import manual_seed
 from refiners.fluxion.model_converter import ModelConverter
-from refiners.fluxion.utils import image_to_tensor, no_grad
+from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
 from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention
 from refiners.foundationals.segment_anything.model import SegmentAnythingH
 from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer
@@ -69,7 +69,7 @@ def facebook_sam_h(facebook_sam_h_weights: Path, test_device: torch.device) -> F
     from segment_anything import build_sam_vit_h  # type: ignore

     sam_h = cast(FacebookSAM, build_sam_vit_h())
-    sam_h.load_state_dict(state_dict=torch.load(f=facebook_sam_h_weights))  # type: ignore
+    sam_h.load_state_dict(state_dict=load_tensors(facebook_sam_h_weights))
     return sam_h.to(device=test_device)