mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
fix weight loading for float16 LoRAs
This commit is contained in:
parent
dd0cca5855
commit
be54cfc016
|
@ -43,8 +43,8 @@ class Lora(fl.Chain):
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
|
|
||||||
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
|
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
|
||||||
self.Linear_1.weight = TorchParameter(down_weight)
|
self.Linear_1.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
|
||||||
self.Linear_2.weight = TorchParameter(up_weight)
|
self.Linear_2.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def up_weight(self) -> Tensor:
|
def up_weight(self) -> Tensor:
|
||||||
|
|
|
@ -788,6 +788,46 @@ def test_diffusion_lora(
|
||||||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_diffusion_lora_float16(
|
||||||
|
sd15_std_float16: StableDiffusion_1,
|
||||||
|
lora_data_pokemon: tuple[Image.Image, Path],
|
||||||
|
test_device: torch.device,
|
||||||
|
):
|
||||||
|
sd15 = sd15_std_float16
|
||||||
|
n_steps = 30
|
||||||
|
|
||||||
|
expected_image, lora_weights_path = lora_data_pokemon
|
||||||
|
|
||||||
|
if not lora_weights_path.is_file():
|
||||||
|
warn(f"could not find weights at {lora_weights_path}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
|
||||||
|
prompt = "a cute cat"
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
||||||
|
|
||||||
|
sd15.set_num_inference_steps(n_steps)
|
||||||
|
|
||||||
|
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=1.0).inject()
|
||||||
|
|
||||||
|
manual_seed(2)
|
||||||
|
x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for step in sd15.steps:
|
||||||
|
x = sd15(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
condition_scale=7.5,
|
||||||
|
)
|
||||||
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
|
||||||
|
ensure_similar_images(predicted_image, expected_image, min_psnr=33, min_ssim=0.98)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def test_diffusion_lora_twice(
|
def test_diffusion_lora_twice(
|
||||||
sd15_std: StableDiffusion_1,
|
sd15_std: StableDiffusion_1,
|
||||||
|
|
Loading…
Reference in a new issue