mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 23:28:45 +00:00
initialize StableDiffusion_1_Inpainting with a 9 channel SD1Unet if not provided
This commit is contained in:
parent
f32ccc3474
commit
7aff743019
|
@ -198,8 +198,14 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
|
||||||
) -> None:
|
) -> None:
|
||||||
self.mask_latents: Tensor | None = None
|
self.mask_latents: Tensor | None = None
|
||||||
self.target_image_latents: Tensor | None = None
|
self.target_image_latents: Tensor | None = None
|
||||||
|
unet = unet or SD1UNet(in_channels=9)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, solver=solver, device=device, dtype=dtype
|
unet=unet,
|
||||||
|
lda=lda,
|
||||||
|
clip_text_encoder=clip_text_encoder,
|
||||||
|
solver=solver,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
from refiners.fluxion.utils import manual_seed, no_grad
|
from refiners.fluxion.utils import manual_seed, no_grad
|
||||||
|
from refiners.foundationals.latent_diffusion import StableDiffusion_1_Inpainting
|
||||||
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
|
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
|
||||||
|
|
||||||
|
|
||||||
|
@ -12,3 +14,18 @@ def test_sample_noise():
|
||||||
latents_1 = LatentDiffusionModel.sample_noise(size=(1, 4, 64, 64), offset_noise=0.0)
|
latents_1 = LatentDiffusionModel.sample_noise(size=(1, 4, 64, 64), offset_noise=0.0)
|
||||||
|
|
||||||
assert torch.allclose(latents_0, latents_1, atol=1e-6, rtol=0)
|
assert torch.allclose(latents_0, latents_1, atol=1e-6, rtol=0)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_sd1_inpainting(test_device: torch.device) -> None:
|
||||||
|
sd = StableDiffusion_1_Inpainting(device=test_device)
|
||||||
|
|
||||||
|
latent_noise = torch.randn(1, 4, 64, 64, device=test_device)
|
||||||
|
target_image = Image.new("RGB", (512, 512))
|
||||||
|
mask = Image.new("L", (512, 512))
|
||||||
|
|
||||||
|
sd.set_inpainting_conditions(target_image=target_image, mask=mask)
|
||||||
|
text_embedding = sd.compute_clip_text_embedding("")
|
||||||
|
output = sd(latent_noise, step=0, clip_text_embedding=text_embedding)
|
||||||
|
|
||||||
|
assert output.shape == (1, 4, 64, 64)
|
||||||
|
|
Loading…
Reference in a new issue