Improve CrossAttentionAdapter test

This commit is contained in:
Pierre Chapuis 2023-09-12 11:31:12 +02:00
parent dc2c3e0163
commit b69dbc4e5c

View file

@ -1,5 +1,5 @@
import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion.image_prompt import CrossAttentionAdapter
from refiners.foundationals.latent_diffusion.image_prompt import CrossAttentionAdapter, InjectionPoint
def test_cross_attention_adapter() -> None:
@ -7,8 +7,23 @@ def test_cross_attention_adapter() -> None:
adapter = CrossAttentionAdapter(base.Attention).inject()
assert list(base) == [adapter]
assert len(list(adapter.layers(fl.Linear))) == 6
assert len(list(base.layers(fl.Linear))) == 6
injection_points = list(adapter.layers(InjectionPoint))
assert len(injection_points) == 4
for ip in injection_points:
assert len(ip) == 1
assert isinstance(ip[0], fl.Linear)
adapter.eject()
assert len(base) == 1
assert isinstance(base[0], fl.Attention)
assert len(list(adapter.layers(fl.Linear))) == 2
assert len(list(base.layers(fl.Linear))) == 4
injection_points = list(adapter.layers(InjectionPoint))
assert len(injection_points) == 4
for ip in injection_points:
assert len(ip) == 0