From 7e64ba4011997aba4efc5d09c68dbc25e5e83e91 Mon Sep 17 00:00:00 2001
From: Laurent
Date: Tue, 26 Mar 2024 09:48:36 +0000
Subject: [PATCH] modify ip_adapter's CrossAttentionAdapter injection logic

---
 .../latent_diffusion/image_prompt.py | 52 ++++++++++++-------
 tests/adapters/test_ip_adapter.py    |  4 ++
 2 files changed, 38 insertions(+), 18 deletions(-)

diff --git a/src/refiners/foundationals/latent_diffusion/image_prompt.py b/src/refiners/foundationals/latent_diffusion/image_prompt.py
index 768a110..a9c4f16 100644
--- a/src/refiners/foundationals/latent_diffusion/image_prompt.py
+++ b/src/refiners/foundationals/latent_diffusion/image_prompt.py
@@ -282,28 +282,44 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
         target: fl.Attention,
         scale: float = 1.0,
     ) -> None:
-        self._scale = scale
         with self.setup_adapter(target):
-            clone = target.structural_copy()
-            scaled_dot_product = clone.ensure_find(ScaledDotProductAttention)
-            image_cross_attention = ImageCrossAttention(
-                text_cross_attention=clone,
-                scale=self.scale,
-            )
-            clone.replace(
-                old_module=scaled_dot_product,
-                new_module=fl.Sum(
-                    scaled_dot_product,
-                    image_cross_attention,
-                ),
-            )
-            super().__init__(
-                clone,
+            super().__init__(target)
+
+        self._image_cross_attention = [
+            ImageCrossAttention(
+                text_cross_attention=target,
+                scale=scale,
             )
+        ]
+
+    def inject(self, parent: fl.Chain | None = None) -> "CrossAttentionAdapter":
+        sdpa = self.target.ensure_find(ScaledDotProductAttention)
+        # replace the sdpa with a Sum of itself and the ImageCrossAttention
+        self.target.replace(
+            old_module=sdpa,
+            new_module=fl.Sum(
+                sdpa,
+                self.image_cross_attention,
+            ),
+        )
+        return super().inject(parent)
+
+    def eject(self) -> None:
+        # find the parent of the ImageCrossAttention (Sum)
+        parent = self.target.ensure_find_parent(self.image_cross_attention)
+        # unlink the ImageCrossAttention from its parent
+        parent.remove(self.image_cross_attention)
+        # replace the Sum with the original ScaledDotProductAttention
+        sdpa = parent.layer("ScaledDotProductAttention", ScaledDotProductAttention)
+        self.target.replace(
+            old_module=parent,
+            new_module=sdpa,
+        )
+        super().eject()
 
     @property
     def image_cross_attention(self) -> ImageCrossAttention:
-        return self.ensure_find(ImageCrossAttention)
+        return self._image_cross_attention[0]
 
     @property
     def image_key_projection(self) -> fl.Linear:
@@ -315,7 +331,7 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
 
     @property
     def scale(self) -> float:
-        return self._scale
+        return self.image_cross_attention.scale
 
     @scale.setter
     def scale(self, value: float) -> None:
diff --git a/tests/adapters/test_ip_adapter.py b/tests/adapters/test_ip_adapter.py
index 570a46d..8c89c78 100644
--- a/tests/adapters/test_ip_adapter.py
+++ b/tests/adapters/test_ip_adapter.py
@@ -35,6 +35,10 @@ def test_inject_eject(k_unet: type[SD1UNet] | type[SDXLUNet], test_device: torch
     assert repr(unet) != initial_repr
     adapter.eject()
     assert repr(unet) == initial_repr
+    adapter.inject()
+    assert repr(unet) != initial_repr
+    adapter.eject()
+    assert repr(unet) == initial_repr
 
 
 @no_grad()