Mirror of https://github.com/finegrain-ai/refiners.git, synced 2024-11-09 23:12:02 +00:00
modify ip_adapter's CrossAttentionAdapters injection logic
commit 7e64ba4011
parent df0cc2aeb8
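In short, this change defers the structural modification of the wrapped fl.Attention: the constructor now only registers the target and keeps the ImageCrossAttention aside, and the wrapping of the ScaledDotProductAttention in an fl.Sum happens in inject(), with eject() undoing it. Below is a minimal usage sketch; the import path, the fl.Attention construction, and the scale value are illustrative assumptions, not part of this commit.

    # Minimal usage sketch (assumed import path and layer construction).
    import refiners.fluxion.layers as fl
    from refiners.foundationals.latent_diffusion.image_prompt import CrossAttentionAdapter

    attention = fl.Attention(embedding_dim=320)  # hypothetical target layer
    adapter = CrossAttentionAdapter(target=attention, scale=0.7)

    # The target is left untouched until inject() splices the ImageCrossAttention
    # into a Sum with the ScaledDotProductAttention; eject() restores the original.
    adapter.inject()
    adapter.eject()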
@@ -282,28 +282,44 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
         target: fl.Attention,
         scale: float = 1.0,
     ) -> None:
-        self._scale = scale
         with self.setup_adapter(target):
-            clone = target.structural_copy()
-            scaled_dot_product = clone.ensure_find(ScaledDotProductAttention)
-            image_cross_attention = ImageCrossAttention(
-                text_cross_attention=clone,
-                scale=self.scale,
-            )
-            clone.replace(
-                old_module=scaled_dot_product,
-                new_module=fl.Sum(
-                    scaled_dot_product,
-                    image_cross_attention,
-                ),
-            )
-            super().__init__(
-                clone,
-            )
+            super().__init__(target)
+
+        self._image_cross_attention = [
+            ImageCrossAttention(
+                text_cross_attention=target,
+                scale=scale,
+            )
+        ]
+
+    def inject(self, parent: fl.Chain | None = None) -> "CrossAttentionAdapter":
+        sdpa = self.target.ensure_find(ScaledDotProductAttention)
+        # replace the spda by a Sum of itself and the ImageCrossAttention
+        self.target.replace(
+            old_module=sdpa,
+            new_module=fl.Sum(
+                sdpa,
+                self.image_cross_attention,
+            ),
+        )
+        return super().inject(parent)
+
+    def eject(self) -> None:
+        # find the parent of the ImageCrossAttention (Sum)
+        parent = self.target.ensure_find_parent(self.image_cross_attention)
+        # unlink the ImageCrossAttention from its parent
+        parent.remove(self.image_cross_attention)
+        # replace the Sum by the original ScaledDotProductAttention
+        sdpa = parent.layer("ScaledDotProductAttention", ScaledDotProductAttention)
+        self.target.replace(
+            old_module=parent,
+            new_module=sdpa,
+        )
+        super().eject()
 
     @property
     def image_cross_attention(self) -> ImageCrossAttention:
-        return self.ensure_find(ImageCrossAttention)
+        return self._image_cross_attention[0]
 
     @property
     def image_key_projection(self) -> fl.Linear:
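The inject()/eject() pair added above is ordinary fluxion Chain surgery: ensure_find/replace to wrap a module in a Sum, then ensure_find_parent/remove/replace to undo it. Storing the ImageCrossAttention in a one-element list presumably keeps it out of the Chain's registered children until injection. The toy sketch below shows the same mechanics on plain layers; the Linear/GeLU chain is an illustration, not code from this commit.

    # Toy sketch of the Chain surgery used by inject()/eject() (illustrative layers).
    import refiners.fluxion.layers as fl

    linear = fl.Linear(in_features=4, out_features=4)
    chain = fl.Chain(linear, fl.GeLU())
    extra = fl.Linear(in_features=4, out_features=4)

    # inject(): wrap the existing module in a Sum that also runs the extra branch
    chain.replace(old_module=linear, new_module=fl.Sum(linear, extra))

    # eject(): unlink the extra branch, then put the original module back
    parent = chain.ensure_find_parent(extra)  # the Sum created above
    parent.remove(extra)
    chain.replace(old_module=parent, new_module=linear)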
@@ -315,7 +331,7 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
 
     @property
     def scale(self) -> float:
-        return self._scale
+        return self.image_cross_attention.scale
 
     @scale.setter
     def scale(self, value: float) -> None:
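With _scale gone, the adapter no longer caches the scale itself; it reads it from the ImageCrossAttention. A short sketch of that delegation, reusing the assumed construction from the first example:

    # Sketch of the scale delegation (constructor arguments are illustrative).
    import refiners.fluxion.layers as fl
    from refiners.foundationals.latent_diffusion.image_prompt import CrossAttentionAdapter

    adapter = CrossAttentionAdapter(target=fl.Attention(embedding_dim=320), scale=0.5)
    # scale is read from the ImageCrossAttention rather than a private attribute
    assert adapter.scale == adapter.image_cross_attention.scale == 0.5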
@@ -35,6 +35,10 @@ def test_inject_eject(k_unet: type[SD1UNet] | type[SDXLUNet], test_device: torch
     assert repr(unet) != initial_repr
     adapter.eject()
     assert repr(unet) == initial_repr
+    adapter.inject()
+    assert repr(unet) != initial_repr
+    adapter.eject()
+    assert repr(unet) == initial_repr
 
 
 @no_grad()