mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
add static typing to __call__ method for latent_diffusion models ; fix multi_diffusion bug that wasn't taking guidance_scale into account
This commit is contained in:
parent
a2ee705783
commit
f26b6ee00a
|
@ -68,6 +68,14 @@ class StableDiffusion_1(LatentDiffusionModel):
|
|||
dtype=dtype,
|
||||
)
|
||||
|
||||
def __call__(self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5) -> Tensor:
|
||||
return super().__call__(
|
||||
x,
|
||||
step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=condition_scale,
|
||||
)
|
||||
|
||||
def compute_clip_text_embedding(self, text: str | list[str], negative_text: str | list[str] = "") -> Tensor:
|
||||
"""Compute the CLIP text embedding associated with the given prompt and negative prompt.
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ class SD1MultiDiffusion(MultiDiffusion[StableDiffusion_1, DiffusionTarget]):
|
|||
x=x,
|
||||
step=step,
|
||||
clip_text_embedding=target.clip_text_embedding,
|
||||
scale=target.condition_scale,
|
||||
condition_scale=target.condition_scale,
|
||||
)
|
||||
|
||||
|
||||
|
@ -37,5 +37,5 @@ class SD1InpaintingMultiDiffusion(MultiDiffusion[StableDiffusion_1_Inpainting, I
|
|||
x=x,
|
||||
step=step,
|
||||
clip_text_embedding=target.clip_text_embedding,
|
||||
scale=target.condition_scale,
|
||||
condition_scale=target.condition_scale,
|
||||
)
|
||||
|
|
|
@ -65,6 +65,25 @@ class StableDiffusion_XL(LatentDiffusionModel):
|
|||
dtype=dtype,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: Tensor,
|
||||
step: int,
|
||||
*,
|
||||
clip_text_embedding: Tensor,
|
||||
pooled_text_embedding: Tensor,
|
||||
time_ids: Tensor,
|
||||
condition_scale: float = 5.0,
|
||||
) -> Tensor:
|
||||
return super().__call__(
|
||||
x=x,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
pooled_text_embedding=pooled_text_embedding,
|
||||
time_ids=time_ids,
|
||||
condition_scale=condition_scale,
|
||||
)
|
||||
|
||||
def compute_clip_text_embedding(
|
||||
self, text: str | list[str], negative_text: str | list[str] = ""
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
|
|
|
@ -2038,6 +2038,7 @@ def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion:
|
|||
size=(64, 64),
|
||||
offset=(0, 16),
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=3,
|
||||
start_step=0,
|
||||
)
|
||||
noise = torch.randn(1, 4, 64, 80, device=sd.device, dtype=sd.dtype)
|
||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 622 KiB After Width: | Height: | Size: 629 KiB |
Loading…
Reference in a new issue