mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
add static typing to __call__ method for latent_diffusion models ; fix multi_diffusion bug that wasn't taking guidance_scale into account
This commit is contained in:
parent
a2ee705783
commit
f26b6ee00a
|
@ -68,6 +68,14 @@ class StableDiffusion_1(LatentDiffusionModel):
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __call__(self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5) -> Tensor:
|
||||||
|
return super().__call__(
|
||||||
|
x,
|
||||||
|
step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
condition_scale=condition_scale,
|
||||||
|
)
|
||||||
|
|
||||||
def compute_clip_text_embedding(self, text: str | list[str], negative_text: str | list[str] = "") -> Tensor:
|
def compute_clip_text_embedding(self, text: str | list[str], negative_text: str | list[str] = "") -> Tensor:
|
||||||
"""Compute the CLIP text embedding associated with the given prompt and negative prompt.
|
"""Compute the CLIP text embedding associated with the given prompt and negative prompt.
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ class SD1MultiDiffusion(MultiDiffusion[StableDiffusion_1, DiffusionTarget]):
|
||||||
x=x,
|
x=x,
|
||||||
step=step,
|
step=step,
|
||||||
clip_text_embedding=target.clip_text_embedding,
|
clip_text_embedding=target.clip_text_embedding,
|
||||||
scale=target.condition_scale,
|
condition_scale=target.condition_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -37,5 +37,5 @@ class SD1InpaintingMultiDiffusion(MultiDiffusion[StableDiffusion_1_Inpainting, I
|
||||||
x=x,
|
x=x,
|
||||||
step=step,
|
step=step,
|
||||||
clip_text_embedding=target.clip_text_embedding,
|
clip_text_embedding=target.clip_text_embedding,
|
||||||
scale=target.condition_scale,
|
condition_scale=target.condition_scale,
|
||||||
)
|
)
|
||||||
|
|
|
@ -65,6 +65,25 @@ class StableDiffusion_XL(LatentDiffusionModel):
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
step: int,
|
||||||
|
*,
|
||||||
|
clip_text_embedding: Tensor,
|
||||||
|
pooled_text_embedding: Tensor,
|
||||||
|
time_ids: Tensor,
|
||||||
|
condition_scale: float = 5.0,
|
||||||
|
) -> Tensor:
|
||||||
|
return super().__call__(
|
||||||
|
x=x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
pooled_text_embedding=pooled_text_embedding,
|
||||||
|
time_ids=time_ids,
|
||||||
|
condition_scale=condition_scale,
|
||||||
|
)
|
||||||
|
|
||||||
def compute_clip_text_embedding(
|
def compute_clip_text_embedding(
|
||||||
self, text: str | list[str], negative_text: str | list[str] = ""
|
self, text: str | list[str], negative_text: str | list[str] = ""
|
||||||
) -> tuple[Tensor, Tensor]:
|
) -> tuple[Tensor, Tensor]:
|
||||||
|
|
|
@ -2038,6 +2038,7 @@ def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion:
|
||||||
size=(64, 64),
|
size=(64, 64),
|
||||||
offset=(0, 16),
|
offset=(0, 16),
|
||||||
clip_text_embedding=clip_text_embedding,
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
condition_scale=3,
|
||||||
start_step=0,
|
start_step=0,
|
||||||
)
|
)
|
||||||
noise = torch.randn(1, 4, 64, 80, device=sd.device, dtype=sd.dtype)
|
noise = torch.randn(1, 4, 64, 80, device=sd.device, dtype=sd.dtype)
|
||||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 622 KiB After Width: | Height: | Size: 629 KiB |
Loading…
Reference in a new issue