mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
modify ip_adapter's ImageCrossAttention scale getter and setter
this new version makes it robust in case mulitple Mulitply-s are inside the Chain (e.g. if the Linear layers are LoRA-ified)
This commit is contained in:
parent
7e64ba4011
commit
a0715806d2
|
@ -235,7 +235,7 @@ class PerceiverResampler(fl.Chain):
|
||||||
|
|
||||||
class ImageCrossAttention(fl.Chain):
|
class ImageCrossAttention(fl.Chain):
|
||||||
def __init__(self, text_cross_attention: fl.Attention, scale: float = 1.0) -> None:
|
def __init__(self, text_cross_attention: fl.Attention, scale: float = 1.0) -> None:
|
||||||
self._scale = scale
|
self._multiply = [fl.Multiply(scale)]
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Distribute(
|
fl.Distribute(
|
||||||
fl.Identity(),
|
fl.Identity(),
|
||||||
|
@ -263,17 +263,20 @@ class ImageCrossAttention(fl.Chain):
|
||||||
ScaledDotProductAttention(
|
ScaledDotProductAttention(
|
||||||
num_heads=text_cross_attention.num_heads, is_causal=text_cross_attention.is_causal
|
num_heads=text_cross_attention.num_heads, is_causal=text_cross_attention.is_causal
|
||||||
),
|
),
|
||||||
fl.Multiply(self.scale),
|
self.multiply,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def multiply(self) -> fl.Multiply:
|
||||||
|
return self._multiply[0]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def scale(self) -> float:
|
def scale(self) -> float:
|
||||||
return self._scale
|
return self.multiply.scale
|
||||||
|
|
||||||
@scale.setter
|
@scale.setter
|
||||||
def scale(self, value: float) -> None:
|
def scale(self, value: float) -> None:
|
||||||
self._scale = value
|
self.multiply.scale = value
|
||||||
self.ensure_find(fl.Multiply).scale = value
|
|
||||||
|
|
||||||
|
|
||||||
class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
|
class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
|
||||||
|
@ -335,7 +338,6 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
|
||||||
|
|
||||||
@scale.setter
|
@scale.setter
|
||||||
def scale(self, value: float) -> None:
|
def scale(self, value: float) -> None:
|
||||||
self._scale = value
|
|
||||||
self.image_cross_attention.scale = value
|
self.image_cross_attention.scale = value
|
||||||
|
|
||||||
def load_weights(self, key_tensor: Tensor, value_tensor: Tensor) -> None:
|
def load_weights(self, key_tensor: Tensor, value_tensor: Tensor) -> None:
|
||||||
|
|
Loading…
Reference in a new issue