From d4dd45fd4d4cc011e674eff6cc880ac473deeeb0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Deltheil?= Date: Wed, 6 Sep 2023 14:55:19 +0200 Subject: [PATCH] use Module's load_from_safetensors Instead of manual calls to load_state_dict --- README.md | 6 +++--- tests/e2e/test_diffusion.py | 30 +++++++++++++++--------------- tests/e2e/test_preprocessors.py | 4 ++-- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 3d548c1..15176bb 100644 --- a/README.md +++ b/README.md @@ -237,9 +237,9 @@ import torch sd15 = StableDiffusion_1(device="cuda") -sd15.clip_text_encoder.load_state_dict(load_from_safetensors("clip.safetensors")) -sd15.lda.load_state_dict(load_from_safetensors("lda.safetensors")) -sd15.unet.load_state_dict(load_from_safetensors("unet.safetensors")) +sd15.clip_text_encoder.load_from_safetensors("clip.safetensors") +sd15.lda.load_from_safetensors("lda.safetensors") +sd15.unet.load_from_safetensors("unet.safetensors") SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path="pokemon_lora.safetensors", scale=1.0).inject() diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index 3e2a981..fd50da5 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -192,9 +192,9 @@ def sd15_std( sd15 = StableDiffusion_1(device=test_device) - sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights)) - sd15.lda.load_state_dict(load_from_safetensors(lda_weights)) - sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std)) + sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) + sd15.lda.load_from_safetensors(lda_weights) + sd15.unet.load_from_safetensors(unet_weights_std) return sd15 @@ -209,9 +209,9 @@ def sd15_std_float16( sd15 = StableDiffusion_1(device=test_device, dtype=torch.float16) - sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights)) - sd15.lda.load_state_dict(load_from_safetensors(lda_weights)) - sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std)) + sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) + sd15.lda.load_from_safetensors(lda_weights) + sd15.unet.load_from_safetensors(unet_weights_std) return sd15 @@ -227,9 +227,9 @@ def sd15_inpainting( unet = SD1UNet(in_channels=9) sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device) - sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights)) - sd15.lda.load_state_dict(load_from_safetensors(lda_weights)) - sd15.unet.load_state_dict(load_from_safetensors(unet_weights_inpainting)) + sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) + sd15.lda.load_from_safetensors(lda_weights) + sd15.unet.load_from_safetensors(unet_weights_inpainting) return sd15 @@ -245,9 +245,9 @@ def sd15_inpainting_float16( unet = SD1UNet(in_channels=9) sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device, dtype=torch.float16) - sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights)) - sd15.lda.load_state_dict(load_from_safetensors(lda_weights)) - sd15.unet.load_state_dict(load_from_safetensors(unet_weights_inpainting)) + sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) + sd15.lda.load_from_safetensors(lda_weights) + sd15.unet.load_from_safetensors(unet_weights_inpainting) return sd15 @@ -263,9 +263,9 @@ def sd15_ddim( ddim_scheduler = DDIM(num_inference_steps=20) sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device) - sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights)) - sd15.lda.load_state_dict(load_from_safetensors(lda_weights)) - sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std)) + sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) + sd15.lda.load_from_safetensors(lda_weights) + sd15.unet.load_from_safetensors(unet_weights_std) return sd15 diff --git a/tests/e2e/test_preprocessors.py b/tests/e2e/test_preprocessors.py index 21b9355..7dd9ddc 100644 --- a/tests/e2e/test_preprocessors.py +++ b/tests/e2e/test_preprocessors.py @@ -5,7 +5,7 @@ from warnings import warn from PIL import Image from pathlib import Path -from refiners.fluxion.utils import load_from_safetensors, image_to_tensor, tensor_to_image +from refiners.fluxion.utils import image_to_tensor, tensor_to_image from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings from tests.utils import ensure_similar_images @@ -38,7 +38,7 @@ def informative_drawings_weights(test_weights_path: Path) -> Path: @pytest.fixture def informative_drawings_model(informative_drawings_weights: Path, test_device: torch.device) -> InformativeDrawings: model = InformativeDrawings(device=test_device) - model.load_state_dict(load_from_safetensors(informative_drawings_weights)) + model.load_from_safetensors(informative_drawings_weights) return model