2023-09-08 16:51:54 +00:00
|
|
|
import refiners.fluxion.layers as fl
|
2023-09-12 09:31:12 +00:00
|
|
|
from refiners.foundationals.latent_diffusion.image_prompt import CrossAttentionAdapter, InjectionPoint
|
2023-09-08 16:51:54 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_cross_attention_adapter() -> None:
    """Exercise CrossAttentionAdapter through a full inject/eject cycle.

    While injected, the adapter must be the sole child of the chain. It must
    expose four InjectionPoint slots, each holding exactly one Linear. After
    ejection, the chain must again hold a single Attention and every
    InjectionPoint must be empty. The Linear counts are checked on both the
    adapter and the chain in each state.
    """

    def count_linears(module: fl.Chain) -> int:
        # Number of fl.Linear layers yielded by `module.layers`.
        return len(list(module.layers(fl.Linear)))

    def injection_points_of(module: fl.Chain) -> list[InjectionPoint]:
        # Materialize the InjectionPoint layers yielded by `module.layers`.
        return list(module.layers(InjectionPoint))

    chain = fl.Chain(fl.Attention(embedding_dim=4))
    adapter = CrossAttentionAdapter(chain.Attention).inject()

    # --- injected state ---
    # The adapter now sits where the Attention used to be, as the only child.
    assert list(chain) == [adapter]
    assert count_linears(adapter) == 6
    assert count_linears(chain) == 6

    injected_points = injection_points_of(adapter)
    assert len(injected_points) == 4
    for point in injected_points:
        # While injected, each slot wraps exactly one Linear layer.
        assert len(point) == 1
        assert isinstance(point[0], fl.Linear)

    adapter.eject()

    # --- ejected state ---
    # The chain is back to a single Attention child.
    assert len(chain) == 1
    assert isinstance(chain[0], fl.Attention)
    assert count_linears(adapter) == 2
    assert count_linears(chain) == 4

    ejected_points = injection_points_of(adapter)
    assert len(ejected_points) == 4
    # Ejection leaves every injection point empty.
    assert all(len(point) == 0 for point in ejected_points)