feature: support self-attention guidance with SD1 inpainting model

This commit is contained in:
Bryce 2023-11-18 12:17:10 -08:00 committed by Cédric Deltheil
parent ab0915d052
commit f666bc82f5

View file

@@ -143,3 +143,30 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
self.target_image_latents = self.lda.encode(x=masked_init_image) self.target_image_latents = self.lda.encode(x=masked_init_image)
return self.mask_latents, self.target_image_latents return self.mask_latents, self.target_image_latents
def compute_self_attention_guidance(
    self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
) -> Tensor:
    """Return the self-attention guidance term for the inpainting UNet.

    The SAG adapter degrades the latents, which are then concatenated with
    ``self.mask_latents`` and ``self.target_image_latents`` along ``dim=1``
    and passed through ``self.unet``. The UNet context is set from the first
    chunk of ``clip_text_embedding`` before that pass. The returned value is
    ``sag.scale * (noise - degraded_noise)``.

    Args:
        x: Latents handed to the adapter's ``compute_degraded_latents``.
        noise: Noise tensor; passed to the adapter and used as the minuend
            of the final difference.
        step: Index into ``self.scheduler.timesteps``.
        clip_text_embedding: Split with ``chunk(2)``; only the first chunk
            (treated here as the negative embedding) is used.
        **kwargs: Forwarded unchanged to ``self.set_unet_context``.

    Returns:
        The adapter scale multiplied by ``noise - degraded_noise``.

    Raises:
        AssertionError: If no SAG adapter is found, or if ``mask_latents``
            or ``target_image_latents`` is ``None``.
    """
    adapter = self._find_sag_adapter()
    assert adapter is not None
    assert self.mask_latents is not None
    assert self.target_image_latents is not None

    blurred_latents = adapter.compute_degraded_latents(
        scheduler=self.scheduler, latents=x, noise=noise, step=step, classifier_free_guidance=True
    )

    # Exactly two chunks are expected; only the first one drives this pass.
    uncond_embedding, _discarded = clip_text_embedding.chunk(2)
    self.set_unet_context(
        timestep=self.scheduler.timesteps[step].unsqueeze(dim=0),
        clip_text_embedding=uncond_embedding,
        **kwargs,
    )

    # Channel-wise stack: degraded latents, then mask, then target-image latents.
    unet_inputs = (blurred_latents, self.mask_latents, self.target_image_latents)
    degraded_prediction = self.unet(torch.cat(tensors=unet_inputs, dim=1))

    return adapter.scale * (noise - degraded_prediction)