simplify even more CrossAttentionAdapter

Following Laurent2916's idea: see #167
This commit is contained in:
limiteinductive 2024-01-11 10:29:21 +01:00 committed by Benjamin Trom
parent 3ab8ed2989
commit deed703617

View file

@ -9,7 +9,6 @@ import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import Adapter from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.context import Contexts from refiners.fluxion.context import Contexts
from refiners.fluxion.layers.attentions import ScaledDotProductAttention from refiners.fluxion.layers.attentions import ScaledDotProductAttention
from refiners.fluxion.layers.chain import Distribute
from refiners.fluxion.utils import image_to_tensor, normalize from refiners.fluxion.utils import image_to_tensor, normalize
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
@ -234,16 +233,12 @@ class PerceiverResampler(fl.Chain):
return {"perceiver_resampler": {"x": None}} return {"perceiver_resampler": {"x": None}}
class InjectionPoint(fl.Chain):
pass
class ImageCrossAttention(fl.Chain): class ImageCrossAttention(fl.Chain):
def __init__(self, text_cross_attention: fl.Attention, scale: float = 1.0) -> None: def __init__(self, text_cross_attention: fl.Attention, scale: float = 1.0) -> None:
self.scale = scale self._scale = scale
super().__init__( super().__init__(
fl.Distribute( fl.Distribute(
fl.UseContext(context="ip_adapter", key="query_projection"), fl.Identity(),
fl.Chain( fl.Chain(
fl.UseContext(context="ip_adapter", key="clip_image_embedding"), fl.UseContext(context="ip_adapter", key="clip_image_embedding"),
fl.Linear( fl.Linear(
@ -271,10 +266,14 @@ class ImageCrossAttention(fl.Chain):
fl.Multiply(self.scale), fl.Multiply(self.scale),
) )
@property
def scale(self) -> float:
return self._scale
class SetQueryProjection(fl.Passthrough): @scale.setter
def __init__(self) -> None: def scale(self, value: float) -> None:
super().__init__(fl.GetArg(index=0), fl.SetContext(context="ip_adapter", key="query_projection")) self._scale = value
self.ensure_find(fl.Multiply).scale = value
class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]): class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
@ -283,15 +282,24 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
target: fl.Attention, target: fl.Attention,
scale: float = 1.0, scale: float = 1.0,
) -> None: ) -> None:
self._scale = scale
with self.setup_adapter(target): with self.setup_adapter(target):
super().__init__( clone = target.structural_copy()
fl.Sum( scaled_dot_product = clone.ensure_find(ScaledDotProductAttention)
target[:-1], # original text cross attention image_cross_attention = ImageCrossAttention(
ImageCrossAttention(text_cross_attention=target, scale=scale), text_cross_attention=clone,
), scale=self.scale,
target[-1], # projection )
clone.replace(
old_module=scaled_dot_product,
new_module=fl.Sum(
scaled_dot_product,
image_cross_attention,
),
)
super().__init__(
clone,
) )
self.ensure_find(fl.Attention).insert_after_type(Distribute, SetQueryProjection())
@property @property
def image_cross_attention(self) -> ImageCrossAttention: def image_cross_attention(self) -> ImageCrossAttention:
@ -307,10 +315,11 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
@property @property
def scale(self) -> float: def scale(self) -> float:
return self.image_cross_attention.scale return self._scale
@scale.setter @scale.setter
def scale(self, value: float) -> None: def scale(self, value: float) -> None:
self._scale = value
self.image_cross_attention.scale = value self.image_cross_attention.scale = value
def load_weights(self, key_tensor: Tensor, value_tensor: Tensor) -> None: def load_weights(self, key_tensor: Tensor, value_tensor: Tensor) -> None: