mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-13 00:28:14 +00:00
15 lines
411 B
Python
15 lines
411 B
Python
import refiners.fluxion.layers as fl
|
|
from refiners.foundationals.latent_diffusion.image_prompt import CrossAttentionAdapter
|
|
|
|
|
|
def test_cross_attention_adapter() -> None:
    """Check that CrossAttentionAdapter can be injected into and ejected from a chain.

    After ``inject()`` the adapter must be the chain's only child, in place of
    the wrapped ``fl.Attention``. After ``eject()`` the chain must again hold
    exactly one child, and that child must be the *original* attention layer
    instance, not merely some ``fl.Attention``.
    """
    base = fl.Chain(fl.Attention(embedding_dim=4))
    # Keep a handle on the original layer so we can verify it is restored as-is.
    attention = base.Attention

    adapter = CrossAttentionAdapter(attention).inject()

    # Injection replaces the attention layer in its parent chain with the adapter.
    assert list(base) == [adapter]

    adapter.eject()

    # Ejection restores the chain to a single child...
    assert len(base) == 1
    assert isinstance(base[0], fl.Attention)
    # ...and that child is the very same object that was adapted, not a copy.
    assert base[0] is attention
|