From 2faff9f57a164369e4f56e551074181867d57147 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Deltheil?= Date: Wed, 20 Sep 2023 10:15:17 +0200 Subject: [PATCH] ldm: properly resize non-square init image --- .../foundationals/latent_diffusion/model.py | 8 ++++---- tests/e2e/test_diffusion.py | 15 +++++++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py index 952fa90..353a598 100644 --- a/src/refiners/foundationals/latent_diffusion/model.py +++ b/src/refiners/foundationals/latent_diffusion/model.py @@ -49,16 +49,16 @@ class LatentDiffusionModel(fl.Module, ABC): first_step: int = 0, noise: Tensor | None = None, ) -> Tensor: + height, width = size if noise is None: - height, width = size noise = torch.randn(1, 4, height // 8, width // 8, device=self.device) assert list(noise.shape[2:]) == [ - size[0] // 8, - size[1] // 8, + height // 8, + width // 8, ], f"noise shape is not compatible: {noise.shape}, with size: {size}" if init_image is None: return noise - encoded_image = self.lda.encode_image(image=init_image.resize(size=size)) + encoded_image = self.lda.encode_image(image=init_image.resize(size=(width, height))) return self.scheduler.add_noise(x=encoded_image, noise=noise, step=self.steps[first_step]) @property diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index d7aa74b..11ff09f 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -492,6 +492,21 @@ def test_diffusion_std_init_image( ensure_similar_images(predicted_image, expected_image_std_init_image) +@torch.no_grad() +def test_rectangular_init_latents( + sd15_std: StableDiffusion_1, + cutecat_init: Image.Image, +): + sd15 = sd15_std + + # Just check latents initialization with a non-square image (and not the entire diffusion) + width, height = 512, 504 + rect_init_image = cutecat_init.crop((0, 0, width, height)) + x = sd15.init_latents((height, width), rect_init_image) + + assert sd15.lda.decode_latents(x).size == (width, height) + + @torch.no_grad() def test_diffusion_inpainting( sd15_inpainting: StableDiffusion_1_Inpainting,