modify ip_adapter's CrossAttentionAdapter injection logic

Laurent 2024-03-26 09:48:36 +00:00 committed by Cédric Deltheil
parent df0cc2aeb8
commit 7e64ba4011
2 changed files with 38 additions and 18 deletions


@@ -282,28 +282,44 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
         target: fl.Attention,
         scale: float = 1.0,
     ) -> None:
-        self._scale = scale
         with self.setup_adapter(target):
-            clone = target.structural_copy()
-            scaled_dot_product = clone.ensure_find(ScaledDotProductAttention)
-            image_cross_attention = ImageCrossAttention(
-                text_cross_attention=clone,
-                scale=self.scale,
-            )
-            clone.replace(
-                old_module=scaled_dot_product,
-                new_module=fl.Sum(
-                    scaled_dot_product,
-                    image_cross_attention,
-                ),
-            )
-            super().__init__(
-                clone,
-            )
+            super().__init__(target)
+
+        self._image_cross_attention = [
+            ImageCrossAttention(
+                text_cross_attention=target,
+                scale=scale,
+            )
+        ]
+
+    def inject(self, parent: fl.Chain | None = None) -> "CrossAttentionAdapter":
+        sdpa = self.target.ensure_find(ScaledDotProductAttention)
+        # replace the sdpa with a Sum of itself and the ImageCrossAttention
+        self.target.replace(
+            old_module=sdpa,
+            new_module=fl.Sum(
+                sdpa,
+                self.image_cross_attention,
+            ),
+        )
+        return super().inject(parent)
+
+    def eject(self) -> None:
+        # find the parent of the ImageCrossAttention (the Sum)
+        parent = self.target.ensure_find_parent(self.image_cross_attention)
+        # unlink the ImageCrossAttention from its parent
+        parent.remove(self.image_cross_attention)
+        # replace the Sum with the original ScaledDotProductAttention
+        sdpa = parent.layer("ScaledDotProductAttention", ScaledDotProductAttention)
+        self.target.replace(
+            old_module=parent,
+            new_module=sdpa,
+        )
+        super().eject()
 
     @property
     def image_cross_attention(self) -> ImageCrossAttention:
-        return self.ensure_find(ImageCrossAttention)
+        return self._image_cross_attention[0]
 
     @property
     def image_key_projection(self) -> fl.Linear:
@@ -315,7 +331,7 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
     @property
     def scale(self) -> float:
-        return self._scale
+        return self.image_cross_attention.scale
 
     @scale.setter
     def scale(self, value: float) -> None:
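
Note the constructor change: the ImageCrossAttention is stashed in a one-element plain list rather than assigned directly, which (under torch.nn.Module attribute semantics, assuming refiners' fl.Chain builds on it) keeps it out of the adapter's module tree until inject() grafts it into the target. A minimal usage sketch of the resulting lifecycle, based only on the API shown in this diff; cross_attn, an fl.Attention layer found in a UNet, is an illustrative placeholder:

    adapter = CrossAttentionAdapter(target=cross_attn, scale=1.0)
    adapter.inject()  # wraps the target's ScaledDotProductAttention in a Sum with the ImageCrossAttention
    # ... run the model with image-prompt conditioning ...
    adapter.eject()   # removes the ImageCrossAttention and restores the bare ScaledDotProductAttention
    adapter.inject()  # injecting again is now safe: eject() fully reverses inject()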


@@ -35,6 +35,10 @@ def test_inject_eject(k_unet: type[SD1UNet] | type[SDXLUNet], test_device: torch
     assert repr(unet) != initial_repr
     adapter.eject()
     assert repr(unet) == initial_repr
+    adapter.inject()
+    assert repr(unet) != initial_repr
+    adapter.eject()
+    assert repr(unet) == initial_repr
 
 
 @no_grad()
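
The four added assertions pin down the invariant the reworked inject/eject logic provides: eject() restores the target's exact original structure, so a second injection behaves like the first. A condensed sketch of that round-trip check; make_unet_and_adapter() is a hypothetical stand-in for the test's actual setup:

    unet, adapter = make_unet_and_adapter()  # hypothetical helper, not part of this diff
    initial_repr = repr(unet)
    for _ in range(2):  # two full inject/eject cycles, as the extended test exercises
        adapter.inject()
        assert repr(unet) != initial_repr
        adapter.eject()
        assert repr(unet) == initial_repr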