# refiners/tests/adapters/test_reference_only_control.py

import pytest
from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion import SD1UNet
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
from refiners.foundationals.latent_diffusion.reference_only_control import (
ReferenceOnlyControlAdapter,
SaveLayerNormAdapter,
SelfAttentionInjectionAdapter,
SelfAttentionInjectionPassthrough,
)
@no_grad()
def test_refonly_inject_eject() -> None:
    """Exercise the inject/eject life cycle of ``ReferenceOnlyControlAdapter``.

    The test builds an ``SD1UNet`` with 9 input channels, wraps it in a
    ``ReferenceOnlyControlAdapter`` and checks, in order, that:

    * before injection the UNet has no parent and contains none of the
      reference-only helper layers, and ``eject()`` fails with an
      ``AssertionError`` mentioning "not the first element";
    * after ``inject()`` the adapter is the UNet's parent, exactly one
      ``SelfAttentionInjectionPassthrough`` is present, and there are as many
      ``SaveLayerNormAdapter`` and ``SelfAttentionInjectionAdapter`` layers as
      there are ``CrossAttentionBlock`` layers;
    * a second ``inject()`` fails with an ``AssertionError`` mentioning
      "already injected";
    * after ``eject()`` the UNet is back to its initial state.
    """
    unet = SD1UNet(in_channels=9)
    adapter = ReferenceOnlyControlAdapter(unet)

    # Reference count used below; the test is meaningless if the UNet has no
    # cross-attention block at all.
    nb_cross_attention_blocks = len(list(unet.walk(CrossAttentionBlock)))
    assert nb_cross_attention_blocks > 0

    # Initial state: building the adapter alone must not modify the UNet.
    assert unet.parent is None
    assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 0
    assert len(list(unet.walk(SaveLayerNormAdapter))) == 0
    assert len(list(unet.walk(SelfAttentionInjectionAdapter))) == 0

    # Ejecting an adapter that was never injected must be rejected.
    with pytest.raises(AssertionError) as exc:
        adapter.eject()
    # Checked after the ``with`` block: ``exc.value`` is only set once the
    # context manager has caught the exception.
    assert "not the first element" in str(exc.value)

    adapter.inject()

    # Injected state: adapter wraps the UNet and helper layers are in place.
    assert unet.parent == adapter
    assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 1
    assert len(list(unet.walk(SaveLayerNormAdapter))) == nb_cross_attention_blocks
    assert len(list(unet.walk(SelfAttentionInjectionAdapter))) == nb_cross_attention_blocks

    # Injecting twice must be rejected.
    with pytest.raises(AssertionError) as exc:
        adapter.inject()
    assert "already injected" in str(exc.value)

    adapter.eject()

    # Ejected state: identical to the initial state.
    assert unet.parent is None
    assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 0
    assert len(list(unet.walk(SaveLayerNormAdapter))) == 0
    assert len(list(unet.walk(SelfAttentionInjectionAdapter))) == 0