mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
use Module's load_from_safetensors
Instead of manual calls to load_state_dict
This commit is contained in:
parent
4388968ad3
commit
d4dd45fd4d
|
@ -237,9 +237,9 @@ import torch
|
||||||
|
|
||||||
|
|
||||||
sd15 = StableDiffusion_1(device="cuda")
|
sd15 = StableDiffusion_1(device="cuda")
|
||||||
sd15.clip_text_encoder.load_state_dict(load_from_safetensors("clip.safetensors"))
|
sd15.clip_text_encoder.load_from_safetensors("clip.safetensors")
|
||||||
sd15.lda.load_state_dict(load_from_safetensors("lda.safetensors"))
|
sd15.lda.load_from_safetensors("lda.safetensors")
|
||||||
sd15.unet.load_state_dict(load_from_safetensors("unet.safetensors"))
|
sd15.unet.load_from_safetensors("unet.safetensors")
|
||||||
|
|
||||||
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path="pokemon_lora.safetensors", scale=1.0).inject()
|
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path="pokemon_lora.safetensors", scale=1.0).inject()
|
||||||
|
|
||||||
|
|
|
@ -192,9 +192,9 @@ def sd15_std(
|
||||||
|
|
||||||
sd15 = StableDiffusion_1(device=test_device)
|
sd15 = StableDiffusion_1(device=test_device)
|
||||||
|
|
||||||
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
|
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
||||||
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
|
sd15.lda.load_from_safetensors(lda_weights)
|
||||||
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std))
|
sd15.unet.load_from_safetensors(unet_weights_std)
|
||||||
|
|
||||||
return sd15
|
return sd15
|
||||||
|
|
||||||
|
@ -209,9 +209,9 @@ def sd15_std_float16(
|
||||||
|
|
||||||
sd15 = StableDiffusion_1(device=test_device, dtype=torch.float16)
|
sd15 = StableDiffusion_1(device=test_device, dtype=torch.float16)
|
||||||
|
|
||||||
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
|
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
||||||
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
|
sd15.lda.load_from_safetensors(lda_weights)
|
||||||
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std))
|
sd15.unet.load_from_safetensors(unet_weights_std)
|
||||||
|
|
||||||
return sd15
|
return sd15
|
||||||
|
|
||||||
|
@ -227,9 +227,9 @@ def sd15_inpainting(
|
||||||
unet = SD1UNet(in_channels=9)
|
unet = SD1UNet(in_channels=9)
|
||||||
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device)
|
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device)
|
||||||
|
|
||||||
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
|
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
||||||
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
|
sd15.lda.load_from_safetensors(lda_weights)
|
||||||
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_inpainting))
|
sd15.unet.load_from_safetensors(unet_weights_inpainting)
|
||||||
|
|
||||||
return sd15
|
return sd15
|
||||||
|
|
||||||
|
@ -245,9 +245,9 @@ def sd15_inpainting_float16(
|
||||||
unet = SD1UNet(in_channels=9)
|
unet = SD1UNet(in_channels=9)
|
||||||
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device, dtype=torch.float16)
|
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device, dtype=torch.float16)
|
||||||
|
|
||||||
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
|
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
||||||
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
|
sd15.lda.load_from_safetensors(lda_weights)
|
||||||
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_inpainting))
|
sd15.unet.load_from_safetensors(unet_weights_inpainting)
|
||||||
|
|
||||||
return sd15
|
return sd15
|
||||||
|
|
||||||
|
@ -263,9 +263,9 @@ def sd15_ddim(
|
||||||
ddim_scheduler = DDIM(num_inference_steps=20)
|
ddim_scheduler = DDIM(num_inference_steps=20)
|
||||||
sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device)
|
sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device)
|
||||||
|
|
||||||
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
|
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
||||||
sd15.lda.load_state_dict(load_from_safetensors(lda_weights))
|
sd15.lda.load_from_safetensors(lda_weights)
|
||||||
sd15.unet.load_state_dict(load_from_safetensors(unet_weights_std))
|
sd15.unet.load_from_safetensors(unet_weights_std)
|
||||||
|
|
||||||
return sd15
|
return sd15
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ from warnings import warn
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from refiners.fluxion.utils import load_from_safetensors, image_to_tensor, tensor_to_image
|
from refiners.fluxion.utils import image_to_tensor, tensor_to_image
|
||||||
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
|
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
|
||||||
|
|
||||||
from tests.utils import ensure_similar_images
|
from tests.utils import ensure_similar_images
|
||||||
|
@ -38,7 +38,7 @@ def informative_drawings_weights(test_weights_path: Path) -> Path:
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def informative_drawings_model(informative_drawings_weights: Path, test_device: torch.device) -> InformativeDrawings:
|
def informative_drawings_model(informative_drawings_weights: Path, test_device: torch.device) -> InformativeDrawings:
|
||||||
model = InformativeDrawings(device=test_device)
|
model = InformativeDrawings(device=test_device)
|
||||||
model.load_state_dict(load_from_safetensors(informative_drawings_weights))
|
model.load_from_safetensors(informative_drawings_weights)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue