use Module's load_from_safetensors

Instead of manual calls to load_state_dict
This commit is contained in:
Cédric Deltheil 2023-09-06 14:55:19 +02:00 committed by Cédric Deltheil
parent 4388968ad3
commit d4dd45fd4d
3 changed files with 20 additions and 20 deletions

View file

@ -237,9 +237,9 @@ import torch
sd15 = StableDiffusion_1(device="cuda")
sd15.clip_text_encoder.load_state_dict(load_from_safetensors("clip.safetensors"))
sd15.lda.load_state_dict(load_from_safetensors("lda.safetensors"))
sd15.unet.load_state_dict(load_from_safetensors("unet.safetensors"))
sd15.clip_text_encoder.load_from_safetensors("clip.safetensors")
sd15.lda.load_from_safetensors("lda.safetensors")
sd15.unet.load_from_safetensors("unet.safetensors")
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path="pokemon_lora.safetensors", scale=1.0).inject()

View file

@ -192,9 +192,9 @@ def sd15_std(
sd15 = StableDiffusion_1(device=test_device)
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std))
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
sd15.unet.load_from_safetensors(unet_weights_std)
return sd15
@ -209,9 +209,9 @@ def sd15_std_float16(
sd15 = StableDiffusion_1(device=test_device, dtype=torch.float16)
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std))
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
sd15.unet.load_from_safetensors(unet_weights_std)
return sd15
@ -227,9 +227,9 @@ def sd15_inpainting(
unet = SD1UNet(in_channels=9)
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device)
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_inpainting))
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
sd15.unet.load_from_safetensors(unet_weights_inpainting)
return sd15
@ -245,9 +245,9 @@ def sd15_inpainting_float16(
unet = SD1UNet(in_channels=9)
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device, dtype=torch.float16)
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_inpainting))
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
sd15.unet.load_from_safetensors(unet_weights_inpainting)
return sd15
@ -263,9 +263,9 @@ def sd15_ddim(
ddim_scheduler = DDIM(num_inference_steps=20)
sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device)
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std))
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
sd15.unet.load_from_safetensors(unet_weights_std)
return sd15

View file

@ -5,7 +5,7 @@ from warnings import warn
from PIL import Image
from pathlib import Path
from refiners.fluxion.utils import load_from_safetensors, image_to_tensor, tensor_to_image
from refiners.fluxion.utils import image_to_tensor, tensor_to_image
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
from tests.utils import ensure_similar_images
@ -38,7 +38,7 @@ def informative_drawings_weights(test_weights_path: Path) -> Path:
@pytest.fixture
def informative_drawings_model(informative_drawings_weights: Path, test_device: torch.device) -> InformativeDrawings:
model = InformativeDrawings(device=test_device)
model.load_state_dict(load_from_safetensors(informative_drawings_weights))
model.load_from_safetensors(informative_drawings_weights)
return model