mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
use Module's load_from_safetensors
Instead of manual calls to load_state_dict
This commit is contained in:
parent
4388968ad3
commit
d4dd45fd4d
|
@ -237,9 +237,9 @@ import torch
|
|||
|
||||
|
||||
sd15 = StableDiffusion_1(device="cuda")
|
||||
sd15.clip_text_encoder.load_state_dict(load_from_safetensors("clip.safetensors"))
|
||||
sd15.lda.load_state_dict(load_from_safetensors("lda.safetensors"))
|
||||
sd15.unet.load_state_dict(load_from_safetensors("unet.safetensors"))
|
||||
sd15.clip_text_encoder.load_from_safetensors("clip.safetensors")
|
||||
sd15.lda.load_from_safetensors("lda.safetensors")
|
||||
sd15.unet.load_from_safetensors("unet.safetensors")
|
||||
|
||||
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path="pokemon_lora.safetensors", scale=1.0).inject()
|
||||
|
||||
|
|
|
@ -192,9 +192,9 @@ def sd15_std(
|
|||
|
||||
sd15 = StableDiffusion_1(device=test_device)
|
||||
|
||||
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
|
||||
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
|
||||
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std))
|
||||
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
||||
sd15.lda.load_from_safetensors(lda_weights)
|
||||
sd15.unet.load_from_safetensors(unet_weights_std)
|
||||
|
||||
return sd15
|
||||
|
||||
|
@ -209,9 +209,9 @@ def sd15_std_float16(
|
|||
|
||||
sd15 = StableDiffusion_1(device=test_device, dtype=torch.float16)
|
||||
|
||||
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
|
||||
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
|
||||
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std))
|
||||
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
||||
sd15.lda.load_from_safetensors(lda_weights)
|
||||
sd15.unet.load_from_safetensors(unet_weights_std)
|
||||
|
||||
return sd15
|
||||
|
||||
|
@ -227,9 +227,9 @@ def sd15_inpainting(
|
|||
unet = SD1UNet(in_channels=9)
|
||||
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device)
|
||||
|
||||
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
|
||||
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
|
||||
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_inpainting))
|
||||
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
||||
sd15.lda.load_from_safetensors(lda_weights)
|
||||
sd15.unet.load_from_safetensors(unet_weights_inpainting)
|
||||
|
||||
return sd15
|
||||
|
||||
|
@ -245,9 +245,9 @@ def sd15_inpainting_float16(
|
|||
unet = SD1UNet(in_channels=9)
|
||||
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device, dtype=torch.float16)
|
||||
|
||||
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
|
||||
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
|
||||
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_inpainting))
|
||||
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
||||
sd15.lda.load_from_safetensors(lda_weights)
|
||||
sd15.unet.load_from_safetensors(unet_weights_inpainting)
|
||||
|
||||
return sd15
|
||||
|
||||
|
@ -263,9 +263,9 @@ def sd15_ddim(
|
|||
ddim_scheduler = DDIM(num_inference_steps=20)
|
||||
sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device)
|
||||
|
||||
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
|
||||
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
|
||||
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std))
|
||||
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
||||
sd15.lda.load_from_safetensors(lda_weights)
|
||||
sd15.unet.load_from_safetensors(unet_weights_std)
|
||||
|
||||
return sd15
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ from warnings import warn
|
|||
from PIL import Image
|
||||
from pathlib import Path
|
||||
|
||||
from refiners.fluxion.utils import load_from_safetensors, image_to_tensor, tensor_to_image
|
||||
from refiners.fluxion.utils import image_to_tensor, tensor_to_image
|
||||
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
|
||||
|
||||
from tests.utils import ensure_similar_images
|
||||
|
@ -38,7 +38,7 @@ def informative_drawings_weights(test_weights_path: Path) -> Path:
|
|||
@pytest.fixture
|
||||
def informative_drawings_model(informative_drawings_weights: Path, test_device: torch.device) -> InformativeDrawings:
|
||||
model = InformativeDrawings(device=test_device)
|
||||
model.load_state_dict(load_from_safetensors(informative_drawings_weights))
|
||||
model.load_from_safetensors(informative_drawings_weights)
|
||||
return model
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue