From be54cfc016def31ddfa84b1811b4206d4e06ce8e Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Mon, 11 Sep 2023 14:51:11 +0200 Subject: [PATCH] fix weight loading for float16 LoRAs --- src/refiners/fluxion/adapters/lora.py | 4 +-- tests/e2e/test_diffusion.py | 40 +++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/src/refiners/fluxion/adapters/lora.py b/src/refiners/fluxion/adapters/lora.py index e6b55c2..fe2b067 100644 --- a/src/refiners/fluxion/adapters/lora.py +++ b/src/refiners/fluxion/adapters/lora.py @@ -43,8 +43,8 @@ class Lora(fl.Chain): self.scale = scale def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None: - self.Linear_1.weight = TorchParameter(down_weight) - self.Linear_2.weight = TorchParameter(up_weight) + self.Linear_1.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype)) + self.Linear_2.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype)) @property def up_weight(self) -> Tensor: diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index 55a8aec..2ac62a7 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -788,6 +788,46 @@ def test_diffusion_lora( ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) +@torch.no_grad() +def test_diffusion_lora_float16( + sd15_std_float16: StableDiffusion_1, + lora_data_pokemon: tuple[Image.Image, Path], + test_device: torch.device, +): + sd15 = sd15_std_float16 + n_steps = 30 + + expected_image, lora_weights_path = lora_data_pokemon + + if not lora_weights_path.is_file(): + warn(f"could not find weights at {lora_weights_path}, skipping") + pytest.skip(allow_module_level=True) + + prompt = "a cute cat" + + with torch.no_grad(): + clip_text_embedding = sd15.compute_clip_text_embedding(prompt) + + sd15.set_num_inference_steps(n_steps) + + SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=1.0).inject() + + manual_seed(2) + x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16) + + with torch.no_grad(): + for step in sd15.steps: + x = sd15( + x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=7.5, + ) + predicted_image = sd15.lda.decode_latents(x) + + ensure_similar_images(predicted_image, expected_image, min_psnr=33, min_ssim=0.98) + + @torch.no_grad() def test_diffusion_lora_twice( sd15_std: StableDiffusion_1,