ldm: properly resize non-square init image

This commit is contained in:
Cédric Deltheil 2023-09-20 10:15:17 +02:00 committed by Cédric Deltheil
parent 01aeaf3e36
commit 2faff9f57a
2 changed files with 19 additions and 4 deletions

View file

@ -49,16 +49,16 @@ class LatentDiffusionModel(fl.Module, ABC):
first_step: int = 0, first_step: int = 0,
noise: Tensor | None = None, noise: Tensor | None = None,
) -> Tensor: ) -> Tensor:
if noise is None:
height, width = size height, width = size
if noise is None:
noise = torch.randn(1, 4, height // 8, width // 8, device=self.device) noise = torch.randn(1, 4, height // 8, width // 8, device=self.device)
assert list(noise.shape[2:]) == [ assert list(noise.shape[2:]) == [
size[0] // 8, height // 8,
size[1] // 8, width // 8,
], f"noise shape is not compatible: {noise.shape}, with size: {size}" ], f"noise shape is not compatible: {noise.shape}, with size: {size}"
if init_image is None: if init_image is None:
return noise return noise
encoded_image = self.lda.encode_image(image=init_image.resize(size=size)) encoded_image = self.lda.encode_image(image=init_image.resize(size=(width, height)))
return self.scheduler.add_noise(x=encoded_image, noise=noise, step=self.steps[first_step]) return self.scheduler.add_noise(x=encoded_image, noise=noise, step=self.steps[first_step])
@property @property

View file

@ -492,6 +492,21 @@ def test_diffusion_std_init_image(
ensure_similar_images(predicted_image, expected_image_std_init_image) ensure_similar_images(predicted_image, expected_image_std_init_image)
@torch.no_grad()
def test_rectangular_init_latents(
sd15_std: StableDiffusion_1,
cutecat_init: Image.Image,
):
sd15 = sd15_std
# Just check latents initialization with a non-square image (and not the entire diffusion)
width, height = 512, 504
rect_init_image = cutecat_init.crop((0, 0, width, height))
x = sd15.init_latents((height, width), rect_init_image)
assert sd15.lda.decode_latents(x).size == (width, height)
@torch.no_grad() @torch.no_grad()
def test_diffusion_inpainting( def test_diffusion_inpainting(
sd15_inpainting: StableDiffusion_1_Inpainting, sd15_inpainting: StableDiffusion_1_Inpainting,