mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 23:28:45 +00:00
improve CrossAttentionAdapter test
This commit is contained in:
parent
dc2c3e0163
commit
b69dbc4e5c
|
@ -1,5 +1,5 @@
|
|||
import refiners.fluxion.layers as fl
|
||||
from refiners.foundationals.latent_diffusion.image_prompt import CrossAttentionAdapter
|
||||
from refiners.foundationals.latent_diffusion.image_prompt import CrossAttentionAdapter, InjectionPoint
|
||||
|
||||
|
||||
def test_cross_attention_adapter() -> None:
|
||||
|
@ -7,8 +7,23 @@ def test_cross_attention_adapter() -> None:
|
|||
adapter = CrossAttentionAdapter(base.Attention).inject()
|
||||
|
||||
assert list(base) == [adapter]
|
||||
assert len(list(adapter.layers(fl.Linear))) == 6
|
||||
assert len(list(base.layers(fl.Linear))) == 6
|
||||
|
||||
injection_points = list(adapter.layers(InjectionPoint))
|
||||
assert len(injection_points) == 4
|
||||
for ip in injection_points:
|
||||
assert len(ip) == 1
|
||||
assert isinstance(ip[0], fl.Linear)
|
||||
|
||||
adapter.eject()
|
||||
|
||||
assert len(base) == 1
|
||||
assert isinstance(base[0], fl.Attention)
|
||||
assert len(list(adapter.layers(fl.Linear))) == 2
|
||||
assert len(list(base.layers(fl.Linear))) == 4
|
||||
|
||||
injection_points = list(adapter.layers(InjectionPoint))
|
||||
assert len(injection_points) == 4
|
||||
for ip in injection_points:
|
||||
assert len(ip) == 0
|
||||
|
|
Loading…
Reference in a new issue