diff --git a/src/refiners/foundationals/latent_diffusion/image_prompt.py b/src/refiners/foundationals/latent_diffusion/image_prompt.py index a9c4f16..14d5bde 100644 --- a/src/refiners/foundationals/latent_diffusion/image_prompt.py +++ b/src/refiners/foundationals/latent_diffusion/image_prompt.py @@ -235,7 +235,7 @@ class PerceiverResampler(fl.Chain): class ImageCrossAttention(fl.Chain): def __init__(self, text_cross_attention: fl.Attention, scale: float = 1.0) -> None: - self._scale = scale + self._multiply = [fl.Multiply(scale)] super().__init__( fl.Distribute( fl.Identity(), @@ -263,17 +263,20 @@ class ImageCrossAttention(fl.Chain): ScaledDotProductAttention( num_heads=text_cross_attention.num_heads, is_causal=text_cross_attention.is_causal ), - fl.Multiply(self.scale), + self.multiply, ) + @property + def multiply(self) -> fl.Multiply: + return self._multiply[0] + @property def scale(self) -> float: - return self._scale + return self.multiply.scale @scale.setter def scale(self, value: float) -> None: - self._scale = value - self.ensure_find(fl.Multiply).scale = value + self.multiply.scale = value class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]): @@ -335,7 +338,6 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]): @scale.setter def scale(self, value: float) -> None: - self._scale = value self.image_cross_attention.scale = value def load_weights(self, key_tensor: Tensor, value_tensor: Tensor) -> None: