deprecate LatentDiffusionAutoencoder's decode_latents

This commit is contained in:
Laurent 2024-10-15 12:52:24 +00:00 committed by Laureηt
parent d8f77dd880
commit 241abfafe4
5 changed files with 15 additions and 20 deletions

View file

@ -92,7 +92,7 @@ with no_grad(): # Disable gradient calculation for memory-efficient inference
pooled_text_embedding=pooled_text_embedding, pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids, time_ids=time_ids,
) )
predicted_image = sdxl.lda.decode_latents(x) predicted_image = sdxl.lda.latents_to_image(x)
predicted_image.save("vanilla_sdxl.png") predicted_image.save("vanilla_sdxl.png")
@ -145,7 +145,7 @@ predicted_image.save("vanilla_sdxl.png")
pooled_text_embedding=pooled_text_embedding, pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids, time_ids=time_ids,
) )
predicted_image = sdxl.lda.decode_latents(x) predicted_image = sdxl.lda.latents_to_image(x)
predicted_image.save("vanilla_sdxl.png") predicted_image.save("vanilla_sdxl.png")
@ -318,7 +318,7 @@ manager.add_loras("pixel-art-lora", load_from_safetensors("pixel-art-xl-v1.1.saf
pooled_text_embedding=pooled_text_embedding, pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids, time_ids=time_ids,
) )
predicted_image = sdxl.lda.decode_latents(x) predicted_image = sdxl.lda.latents_to_image(x)
predicted_image.save("scifi_pixel_sdxl.png") predicted_image.save("scifi_pixel_sdxl.png")
@ -453,7 +453,7 @@ with torch.no_grad():
pooled_text_embedding=pooled_text_embedding, pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids, time_ids=time_ids,
) )
predicted_image = sdxl.lda.decode_latents(x) predicted_image = sdxl.lda.latents_to_image(x)
predicted_image.save("scifi_pixel_IP_sdxl.png") predicted_image.save("scifi_pixel_IP_sdxl.png")
@ -591,7 +591,7 @@ with torch.no_grad():
pooled_text_embedding=pooled_text_embedding, pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids, time_ids=time_ids,
) )
predicted_image = sdxl.lda.decode_latents(x) predicted_image = sdxl.lda.latents_to_image(x)
predicted_image.save("scifi_pixel_IP_T2I_sdxl.png") predicted_image.save("scifi_pixel_IP_T2I_sdxl.png")

View file

@ -369,11 +369,6 @@ class LatentDiffusionAutoencoder(Chain):
x = 2 * x - 1 x = 2 * x - 1
return self.encode(x) return self.encode(x)
# backward-compatibility alias
# TODO: deprecate this method
def decode_latents(self, x: Tensor) -> Image.Image:
return self.latents_to_image(x)
def latents_to_image(self, x: Tensor) -> Image.Image: def latents_to_image(self, x: Tensor) -> Image.Image:
""" """
Decode latents to an image. Decode latents to an image.

View file

@ -61,7 +61,7 @@ class StableDiffusion_1(LatentDiffusionModel):
for step in sd15.steps: for step in sd15.steps:
x = sd15(x, step=step, clip_text_embedding=clip_text_embedding) x = sd15(x, step=step, clip_text_embedding=clip_text_embedding)
predicted_image = sd15.lda.decode_latents(x) predicted_image = sd15.lda.latents_to_image(x)
predicted_image.save("output.png") predicted_image.save("output.png")
``` ```
""" """

View file

@ -1496,7 +1496,7 @@ def test_diffusion_sdxl_control_lora(
) )
# decode latent to image # decode latent to image
predicted_image = sdxl.lda.decode_latents(x) predicted_image = sdxl.lda.latents_to_image(x)
# ensure the predicted image is similar to the expected image # ensure the predicted image is similar to the expected image
ensure_similar_images( ensure_similar_images(
@ -1935,7 +1935,7 @@ def test_diffusion_ip_adapter_multi(
clip_text_embedding=clip_text_embedding, clip_text_embedding=clip_text_embedding,
condition_scale=7.5, condition_scale=7.5,
) )
predicted_image = sd15.lda.decode_latents(x) predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image_ip_adapter_multi, min_psnr=43, min_ssim=0.98) ensure_similar_images(predicted_image, expected_image_ip_adapter_multi, min_psnr=43, min_ssim=0.98)
@ -2245,7 +2245,7 @@ def test_diffusion_sdxl_sliced_attention(
condition_scale=5, condition_scale=5,
) )
predicted_image = sdxl.lda.decode_latents(x) predicted_image = sdxl.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@ -2279,7 +2279,7 @@ def test_diffusion_sdxl_euler_deterministic(
condition_scale=5, condition_scale=5,
) )
predicted_image = sdxl.lda.decode_latents(x) predicted_image = sdxl.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image) ensure_similar_images(predicted_image, expected_image)
@ -2604,7 +2604,7 @@ def test_style_aligned(
) )
# decode latents # decode latents
predicted_images = [sdxl.lda.decode_latents(latent.unsqueeze(0)) for latent in x] predicted_images = sdxl.lda.latents_to_images(x)
# tile all images horizontally # tile all images horizontally
merged_image = Image.new("RGB", (1024 * len(predicted_images), 1024)) merged_image = Image.new("RGB", (1024 * len(predicted_images), 1024))

View file

@ -109,7 +109,7 @@ def test_guide_adapting_sdxl_vanilla(
time_ids=time_ids, time_ids=time_ids,
) )
predicted_image = sdxl.lda.decode_latents(x) predicted_image = sdxl.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@ -151,7 +151,7 @@ def test_guide_adapting_sdxl_single_lora(
time_ids=time_ids, time_ids=time_ids,
) )
predicted_image = sdxl.lda.decode_latents(x) predicted_image = sdxl.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=38, min_ssim=0.98) ensure_similar_images(predicted_image, expected_image, min_psnr=38, min_ssim=0.98)
@ -195,7 +195,7 @@ def test_guide_adapting_sdxl_multiple_loras(
time_ids=time_ids, time_ids=time_ids,
) )
predicted_image = sdxl.lda.decode_latents(x) predicted_image = sdxl.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=38, min_ssim=0.98) ensure_similar_images(predicted_image, expected_image, min_psnr=38, min_ssim=0.98)
@ -255,7 +255,7 @@ def test_guide_adapting_sdxl_loras_ip_adapter(
time_ids=time_ids, time_ids=time_ids,
) )
predicted_image = sdxl.lda.decode_latents(x) predicted_image = sdxl.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=29, min_ssim=0.98) ensure_similar_images(predicted_image, expected_image, min_psnr=29, min_ssim=0.98)