mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
feature: support self-attention guidance with SD1 inpainting model
This commit is contained in:
parent
ab0915d052
commit
f666bc82f5
|
@ -143,3 +143,30 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
|
|||
self.target_image_latents = self.lda.encode(x=masked_init_image)
|
||||
|
||||
return self.mask_latents, self.target_image_latents
|
||||
|
||||
def compute_self_attention_guidance(
|
||||
self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
|
||||
) -> Tensor:
|
||||
sag = self._find_sag_adapter()
|
||||
assert sag is not None
|
||||
assert self.mask_latents is not None
|
||||
assert self.target_image_latents is not None
|
||||
|
||||
degraded_latents = sag.compute_degraded_latents(
|
||||
scheduler=self.scheduler,
|
||||
latents=x,
|
||||
noise=noise,
|
||||
step=step,
|
||||
classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
negative_embedding, _ = clip_text_embedding.chunk(2)
|
||||
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
|
||||
self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
|
||||
x = torch.cat(
|
||||
tensors=(degraded_latents, self.mask_latents, self.target_image_latents),
|
||||
dim=1,
|
||||
)
|
||||
degraded_noise = self.unet(x)
|
||||
|
||||
return sag.scale * (noise - degraded_noise)
|
||||
|
|
Loading…
Reference in a new issue