fix weight loading for float16 LoRAs

This commit is contained in:
Pierre Chapuis 2023-09-11 14:51:11 +02:00
parent dd0cca5855
commit be54cfc016
2 changed files with 42 additions and 2 deletions

View file

@ -43,8 +43,8 @@ class Lora(fl.Chain):
self.scale = scale
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
self.Linear_1.weight = TorchParameter(down_weight)
self.Linear_2.weight = TorchParameter(up_weight)
self.Linear_1.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
self.Linear_2.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
@property
def up_weight(self) -> Tensor:

View file

@ -788,6 +788,46 @@ def test_diffusion_lora(
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
def test_diffusion_lora_float16(
sd15_std_float16: StableDiffusion_1,
lora_data_pokemon: tuple[Image.Image, Path],
test_device: torch.device,
):
sd15 = sd15_std_float16
n_steps = 30
expected_image, lora_weights_path = lora_data_pokemon
if not lora_weights_path.is_file():
warn(f"could not find weights at {lora_weights_path}, skipping")
pytest.skip(allow_module_level=True)
prompt = "a cute cat"
with torch.no_grad():
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
sd15.set_num_inference_steps(n_steps)
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=1.0).inject()
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
with torch.no_grad():
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=33, min_ssim=0.98)
@torch.no_grad()
def test_diffusion_lora_twice(
sd15_std: StableDiffusion_1,