Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-22 06:08:46 +00:00)

Commit 2faff9f57a (parent 01aeaf3e36): ldm: properly resize non-square init image
@@ -49,16 +49,16 @@ class LatentDiffusionModel(fl.Module, ABC):
         first_step: int = 0,
         noise: Tensor | None = None,
     ) -> Tensor:
+        height, width = size
         if noise is None:
-            height, width = size
             noise = torch.randn(1, 4, height // 8, width // 8, device=self.device)
         assert list(noise.shape[2:]) == [
-            size[0] // 8,
-            size[1] // 8,
+            height // 8,
+            width // 8,
         ], f"noise shape is not compatible: {noise.shape}, with size: {size}"
         if init_image is None:
             return noise
-        encoded_image = self.lda.encode_image(image=init_image.resize(size=size))
+        encoded_image = self.lda.encode_image(image=init_image.resize(size=(width, height)))
         return self.scheduler.add_noise(x=encoded_image, noise=noise, step=self.steps[first_step])

     @property
@@ -492,6 +492,21 @@ def test_diffusion_std_init_image(
     ensure_similar_images(predicted_image, expected_image_std_init_image)


+@torch.no_grad()
+def test_rectangular_init_latents(
+    sd15_std: StableDiffusion_1,
+    cutecat_init: Image.Image,
+):
+    sd15 = sd15_std
+
+    # Just check latents initialization with a non-square image (and not the entire diffusion)
+    width, height = 512, 504
+    rect_init_image = cutecat_init.crop((0, 0, width, height))
+    x = sd15.init_latents((height, width), rect_init_image)
+
+    assert sd15.lda.decode_latents(x).size == (width, height)
+
+
 @torch.no_grad()
 def test_diffusion_inpainting(
     sd15_inpainting: StableDiffusion_1_Inpainting,
Loading…
Reference in a new issue