From 7aff743019751ab7a488506481354522f7d0332c Mon Sep 17 00:00:00 2001
From: Laurent
Date: Tue, 23 Apr 2024 14:43:32 +0000
Subject: [PATCH] initialize StableDiffusion_1_Inpainting with a 9 channel SD1Unet if not provided

---
Reviewer notes (text in this section is not part of the commit message):
- model.py: when no `unet` is passed, the constructor now builds
  `SD1UNet(in_channels=9)` before delegating to the parent `__init__`.
- test_model.py: adds `test_sd1_inpainting`, which runs one call at step 0
  and checks that the output shape is (1, 4, 64, 64).
- The line structure of this patch was restored from a whitespace-collapsed
  copy; context-line indentation assumes 4-space indentation. Confirm with
  `git apply --check` before applying.

 .../stable_diffusion_1/model.py                 |  8 +++++++-
 .../latent_diffusion/test_model.py              | 17 +++++++++++++++++
 2 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
index ef29151..e019ae0 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
@@ -198,8 +198,14 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
     ) -> None:
         self.mask_latents: Tensor | None = None
         self.target_image_latents: Tensor | None = None
+        unet = unet or SD1UNet(in_channels=9)
         super().__init__(
-            unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, solver=solver, device=device, dtype=dtype
+            unet=unet,
+            lda=lda,
+            clip_text_encoder=clip_text_encoder,
+            solver=solver,
+            device=device,
+            dtype=dtype,
         )
 
     def forward(
diff --git a/tests/foundationals/latent_diffusion/test_model.py b/tests/foundationals/latent_diffusion/test_model.py
index 78443d4..dafd38f 100644
--- a/tests/foundationals/latent_diffusion/test_model.py
+++ b/tests/foundationals/latent_diffusion/test_model.py
@@ -1,6 +1,8 @@
 import torch
+from PIL import Image
 
 from refiners.fluxion.utils import manual_seed, no_grad
+from refiners.foundationals.latent_diffusion import StableDiffusion_1_Inpainting
 from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
 
 
@@ -12,3 +14,18 @@ def test_sample_noise():
     latents_1 = LatentDiffusionModel.sample_noise(size=(1, 4, 64, 64), offset_noise=0.0)
 
     assert torch.allclose(latents_0, latents_1, atol=1e-6, rtol=0)
+
+
+@no_grad()
+def test_sd1_inpainting(test_device: torch.device) -> None:
+    sd = StableDiffusion_1_Inpainting(device=test_device)
+
+    latent_noise = torch.randn(1, 4, 64, 64, device=test_device)
+    target_image = Image.new("RGB", (512, 512))
+    mask = Image.new("L", (512, 512))
+
+    sd.set_inpainting_conditions(target_image=target_image, mask=mask)
+    text_embedding = sd.compute_clip_text_embedding("")
+    output = sd(latent_noise, step=0, clip_text_embedding=text_embedding)
+
+    assert output.shape == (1, 4, 64, 64)