Add static typing to the __call__ method of latent diffusion models; fix a multi_diffusion bug where the guidance scale (condition_scale) was not being taken into account.

This commit is contained in:
limiteinductive 2024-04-11 09:01:15 +00:00 committed by Benjamin Trom
parent a2ee705783
commit f26b6ee00a
5 changed files with 30 additions and 2 deletions

View file

@ -68,6 +68,14 @@ class StableDiffusion_1(LatentDiffusionModel):
dtype=dtype,
)
def __call__(self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5) -> Tensor:
    """Run one denoising step on the latent `x`.

    Thin, fully-typed wrapper around the generic latent-diffusion call:
    only the SD 1.x-specific signature and default guidance scale live here.

    Args:
        x: The latent to denoise.
        step: The current diffusion step index.
        clip_text_embedding: The CLIP text embedding conditioning the step.
        condition_scale: The classifier-free-guidance scale (default 7.5).
    """
    # Delegate everything to the shared parent implementation; keyword
    # arguments make the pass-through explicit.
    return super().__call__(
        x=x,
        step=step,
        clip_text_embedding=clip_text_embedding,
        condition_scale=condition_scale,
    )
def compute_clip_text_embedding(self, text: str | list[str], negative_text: str | list[str] = "") -> Tensor:
"""Compute the CLIP text embedding associated with the given prompt and negative prompt.

View file

@ -16,7 +16,7 @@ class SD1MultiDiffusion(MultiDiffusion[StableDiffusion_1, DiffusionTarget]):
x=x,
step=step,
clip_text_embedding=target.clip_text_embedding,
scale=target.condition_scale,
condition_scale=target.condition_scale,
)
@ -37,5 +37,5 @@ class SD1InpaintingMultiDiffusion(MultiDiffusion[StableDiffusion_1_Inpainting, I
x=x,
step=step,
clip_text_embedding=target.clip_text_embedding,
scale=target.condition_scale,
condition_scale=target.condition_scale,
)

View file

@ -65,6 +65,25 @@ class StableDiffusion_XL(LatentDiffusionModel):
dtype=dtype,
)
def __call__(
    self,
    x: Tensor,
    step: int,
    *,
    clip_text_embedding: Tensor,
    pooled_text_embedding: Tensor,
    time_ids: Tensor,
    condition_scale: float = 5.0,
) -> Tensor:
    """Run one denoising step on the latent `x` with SDXL conditioning.

    Fully-typed wrapper that forwards the SDXL-specific conditioning
    (pooled embedding and time ids) to the generic latent-diffusion call.

    Args:
        x: The latent to denoise.
        step: The current diffusion step index.
        clip_text_embedding: The CLIP text embedding conditioning the step.
        pooled_text_embedding: The pooled CLIP text embedding.
        time_ids: The SDXL time-ids conditioning tensor.
        condition_scale: The classifier-free-guidance scale (default 5.0).
    """
    # Gather the SDXL-specific keyword conditioning, then delegate to the
    # shared parent implementation.
    conditioning = {
        "clip_text_embedding": clip_text_embedding,
        "pooled_text_embedding": pooled_text_embedding,
        "time_ids": time_ids,
        "condition_scale": condition_scale,
    }
    return super().__call__(x=x, step=step, **conditioning)
def compute_clip_text_embedding(
self, text: str | list[str], negative_text: str | list[str] = ""
) -> tuple[Tensor, Tensor]:

View file

@ -2038,6 +2038,7 @@ def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion:
size=(64, 64),
offset=(0, 16),
clip_text_embedding=clip_text_embedding,
condition_scale=3,
start_step=0,
)
noise = torch.randn(1, 4, 64, 80, device=sd.device, dtype=sd.dtype)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 622 KiB

After

Width:  |  Height:  |  Size: 629 KiB