mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
Simplify CrossAttentionAdapter even further
Follows Laurent2916's idea; see #167
This commit is contained in:
parent
3ab8ed2989
commit
deed703617
|
@ -9,7 +9,6 @@ import refiners.fluxion.layers as fl
|
||||||
from refiners.fluxion.adapters.adapter import Adapter
|
from refiners.fluxion.adapters.adapter import Adapter
|
||||||
from refiners.fluxion.context import Contexts
|
from refiners.fluxion.context import Contexts
|
||||||
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
|
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
|
||||||
from refiners.fluxion.layers.chain import Distribute
|
|
||||||
from refiners.fluxion.utils import image_to_tensor, normalize
|
from refiners.fluxion.utils import image_to_tensor, normalize
|
||||||
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
|
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
|
||||||
|
|
||||||
|
@ -234,16 +233,12 @@ class PerceiverResampler(fl.Chain):
|
||||||
return {"perceiver_resampler": {"x": None}}
|
return {"perceiver_resampler": {"x": None}}
|
||||||
|
|
||||||
|
|
||||||
class InjectionPoint(fl.Chain):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ImageCrossAttention(fl.Chain):
|
class ImageCrossAttention(fl.Chain):
|
||||||
def __init__(self, text_cross_attention: fl.Attention, scale: float = 1.0) -> None:
|
def __init__(self, text_cross_attention: fl.Attention, scale: float = 1.0) -> None:
|
||||||
self.scale = scale
|
self._scale = scale
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Distribute(
|
fl.Distribute(
|
||||||
fl.UseContext(context="ip_adapter", key="query_projection"),
|
fl.Identity(),
|
||||||
fl.Chain(
|
fl.Chain(
|
||||||
fl.UseContext(context="ip_adapter", key="clip_image_embedding"),
|
fl.UseContext(context="ip_adapter", key="clip_image_embedding"),
|
||||||
fl.Linear(
|
fl.Linear(
|
||||||
|
@ -271,10 +266,14 @@ class ImageCrossAttention(fl.Chain):
|
||||||
fl.Multiply(self.scale),
|
fl.Multiply(self.scale),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def scale(self) -> float:
|
||||||
|
return self._scale
|
||||||
|
|
||||||
class SetQueryProjection(fl.Passthrough):
|
@scale.setter
|
||||||
def __init__(self) -> None:
|
def scale(self, value: float) -> None:
|
||||||
super().__init__(fl.GetArg(index=0), fl.SetContext(context="ip_adapter", key="query_projection"))
|
self._scale = value
|
||||||
|
self.ensure_find(fl.Multiply).scale = value
|
||||||
|
|
||||||
|
|
||||||
class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
|
class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
|
||||||
|
@ -283,15 +282,24 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
|
||||||
target: fl.Attention,
|
target: fl.Attention,
|
||||||
scale: float = 1.0,
|
scale: float = 1.0,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
self._scale = scale
|
||||||
with self.setup_adapter(target):
|
with self.setup_adapter(target):
|
||||||
super().__init__(
|
clone = target.structural_copy()
|
||||||
fl.Sum(
|
scaled_dot_product = clone.ensure_find(ScaledDotProductAttention)
|
||||||
target[:-1], # original text cross attention
|
image_cross_attention = ImageCrossAttention(
|
||||||
ImageCrossAttention(text_cross_attention=target, scale=scale),
|
text_cross_attention=clone,
|
||||||
),
|
scale=self.scale,
|
||||||
target[-1], # projection
|
)
|
||||||
|
clone.replace(
|
||||||
|
old_module=scaled_dot_product,
|
||||||
|
new_module=fl.Sum(
|
||||||
|
scaled_dot_product,
|
||||||
|
image_cross_attention,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
super().__init__(
|
||||||
|
clone,
|
||||||
)
|
)
|
||||||
self.ensure_find(fl.Attention).insert_after_type(Distribute, SetQueryProjection())
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def image_cross_attention(self) -> ImageCrossAttention:
|
def image_cross_attention(self) -> ImageCrossAttention:
|
||||||
|
@ -307,10 +315,11 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def scale(self) -> float:
|
def scale(self) -> float:
|
||||||
return self.image_cross_attention.scale
|
return self._scale
|
||||||
|
|
||||||
@scale.setter
|
@scale.setter
|
||||||
def scale(self, value: float) -> None:
|
def scale(self, value: float) -> None:
|
||||||
|
self._scale = value
|
||||||
self.image_cross_attention.scale = value
|
self.image_cross_attention.scale = value
|
||||||
|
|
||||||
def load_weights(self, key_tensor: Tensor, value_tensor: Tensor) -> None:
|
def load_weights(self, key_tensor: Tensor, value_tensor: Tensor) -> None:
|
||||||
|
|
Loading…
Reference in a new issue