mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
fix weight loading for float16 LoRAs
This commit is contained in:
parent
dd0cca5855
commit
be54cfc016
|
@ -43,8 +43,8 @@ class Lora(fl.Chain):
|
|||
self.scale = scale
|
||||
|
||||
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
|
||||
self.Linear_1.weight = TorchParameter(down_weight)
|
||||
self.Linear_2.weight = TorchParameter(up_weight)
|
||||
self.Linear_1.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
|
||||
self.Linear_2.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
|
||||
|
||||
@property
|
||||
def up_weight(self) -> Tensor:
|
||||
|
|
|
@ -788,6 +788,46 @@ def test_diffusion_lora(
|
|||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_diffusion_lora_float16(
|
||||
sd15_std_float16: StableDiffusion_1,
|
||||
lora_data_pokemon: tuple[Image.Image, Path],
|
||||
test_device: torch.device,
|
||||
):
|
||||
sd15 = sd15_std_float16
|
||||
n_steps = 30
|
||||
|
||||
expected_image, lora_weights_path = lora_data_pokemon
|
||||
|
||||
if not lora_weights_path.is_file():
|
||||
warn(f"could not find weights at {lora_weights_path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
|
||||
prompt = "a cute cat"
|
||||
|
||||
with torch.no_grad():
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
|
||||
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=1.0).inject()
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
|
||||
|
||||
with torch.no_grad():
|
||||
for step in sd15.steps:
|
||||
x = sd15(
|
||||
x,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image, min_psnr=33, min_ssim=0.98)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_diffusion_lora_twice(
|
||||
sd15_std: StableDiffusion_1,
|
||||
|
|
Loading…
Reference in a new issue