mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-09 23:12:02 +00:00)
modify ip_adapter's CrossAttentionAdapters injection logic

commit 7e64ba4011, parent df0cc2aeb8
@@ -282,28 +282,44 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
         target: fl.Attention,
         scale: float = 1.0,
     ) -> None:
-        self._scale = scale
         with self.setup_adapter(target):
-            clone = target.structural_copy()
-            scaled_dot_product = clone.ensure_find(ScaledDotProductAttention)
-            image_cross_attention = ImageCrossAttention(
-                text_cross_attention=clone,
-                scale=self.scale,
-            )
-            clone.replace(
-                old_module=scaled_dot_product,
-                new_module=fl.Sum(
-                    scaled_dot_product,
-                    image_cross_attention,
-                ),
-            )
-            super().__init__(
-                clone,
-            )
+            super().__init__(target)
+
+        self._image_cross_attention = [
+            ImageCrossAttention(
+                text_cross_attention=target,
+                scale=scale,
+            )
+        ]
+
+    def inject(self, parent: fl.Chain | None = None) -> "CrossAttentionAdapter":
+        sdpa = self.target.ensure_find(ScaledDotProductAttention)
+        # replace the sdpa by a Sum of itself and the ImageCrossAttention
+        self.target.replace(
+            old_module=sdpa,
+            new_module=fl.Sum(
+                sdpa,
+                self.image_cross_attention,
+            ),
+        )
+        return super().inject(parent)
+
+    def eject(self) -> None:
+        # find the parent of the ImageCrossAttention (Sum)
+        parent = self.target.ensure_find_parent(self.image_cross_attention)
+        # unlink the ImageCrossAttention from its parent
+        parent.remove(self.image_cross_attention)
+        # replace the Sum by the original ScaledDotProductAttention
+        sdpa = parent.layer("ScaledDotProductAttention", ScaledDotProductAttention)
+        self.target.replace(
+            old_module=parent,
+            new_module=sdpa,
+        )
+        super().eject()

     @property
     def image_cross_attention(self) -> ImageCrossAttention:
-        return self.ensure_find(ImageCrossAttention)
+        return self._image_cross_attention[0]

     @property
     def image_key_projection(self) -> fl.Linear:
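Note: the net effect of this hunk is that the target Attention is no longer rewritten at construction time. The fl.Sum wrapping now happens in inject() and is undone in eject(), and the ImageCrossAttention is held in a plain one-element Python list, which presumably keeps it out of the adapter's Chain until injection. A minimal usage sketch (the fl.Attention arguments and the import path are assumptions for illustration; CrossAttentionAdapter's signature comes from the diff above):

    import refiners.fluxion.layers as fl
    # import path assumed from the file under change
    from refiners.foundationals.latent_diffusion.image_prompt import CrossAttentionAdapter

    # hypothetical stand-alone target; the embedding_dim value is illustrative
    attention = fl.Attention(embedding_dim=320)
    container = fl.Chain(attention)

    adapter = CrossAttentionAdapter(target=attention, scale=0.75)
    adapter.inject(container)  # wraps the target's SDPA in fl.Sum(sdpa, image_cross_attention)
    adapter.eject()            # removes the ImageCrossAttention and restores the bare SDPA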
@@ -315,7 +331,7 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):

     @property
     def scale(self) -> float:
-        return self._scale
+        return self.image_cross_attention.scale

     @scale.setter
     def scale(self, value: float) -> None:
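With the cached _scale attribute gone, reads go straight to the ImageCrossAttention module, making it the single source of truth; the setter body sits below this hunk and presumably writes through the same path. A hedged sketch:

    adapter = CrossAttentionAdapter(target=attention, scale=1.0)
    assert adapter.scale == adapter.image_cross_attention.scale  # getter delegates (per the diff)
    adapter.scale = 0.5  # presumably forwarded to the ImageCrossAttention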
@@ -35,6 +35,10 @@ def test_inject_eject(k_unet: type[SD1UNet] | type[SDXLUNet], test_device: torch.device
     assert repr(unet) != initial_repr
     adapter.eject()
     assert repr(unet) == initial_repr
+    adapter.inject()
+    assert repr(unet) != initial_repr
+    adapter.eject()
+    assert repr(unet) == initial_repr


 @no_grad()
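The four added lines run a second inject/eject round trip, checking that the adapter's changes to the UNet (observed via its repr) are not just reversible once but repeatable. Equivalently, the invariant could be written as a loop (a sketch, not the test's literal form):

    for _ in range(2):
        adapter.inject()
        assert repr(unet) != initial_repr
        adapter.eject()
        assert repr(unet) == initial_repr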