2023-08-31 08:40:01 +00:00
|
|
|
import pytest
|
|
|
|
|
2023-12-29 09:59:51 +00:00
|
|
|
from refiners.fluxion.utils import no_grad
|
2023-08-31 08:40:01 +00:00
|
|
|
from refiners.foundationals.latent_diffusion import SD1UNet
|
2023-12-11 10:46:38 +00:00
|
|
|
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
|
2023-09-01 10:02:47 +00:00
|
|
|
from refiners.foundationals.latent_diffusion.reference_only_control import (
|
2023-08-31 08:40:01 +00:00
|
|
|
ReferenceOnlyControlAdapter,
|
2023-09-01 10:02:47 +00:00
|
|
|
SaveLayerNormAdapter,
|
|
|
|
SelfAttentionInjectionAdapter,
|
2023-08-31 08:40:01 +00:00
|
|
|
SelfAttentionInjectionPassthrough,
|
|
|
|
)
|
|
|
|
|
|
|
|
|
2023-12-29 09:59:51 +00:00
|
|
|
@no_grad()
def test_refonly_inject_eject() -> None:
    """Exercise the inject/eject life cycle of ReferenceOnlyControlAdapter on an SD1 UNet.

    Verifies that:
    - a freshly built adapter is not yet attached to the UNet;
    - ejecting before injecting fails with an AssertionError;
    - injecting attaches the adapter and its sub-adapters to the UNet;
    - injecting twice fails with an AssertionError;
    - ejecting restores the UNet to its original, adapter-free state.
    """
    unet = SD1UNet(in_channels=9)
    adapter = ReferenceOnlyControlAdapter(unet)

    def count(layer_type: type) -> int:
        # Number of `layer_type` instances found by walking the UNet.
        return len(list(unet.walk(layer_type)))

    num_blocks = count(CrossAttentionBlock)
    assert num_blocks > 0

    # Before injection: the UNet has no parent and carries none of the adapter layers.
    assert unet.parent is None
    assert count(SelfAttentionInjectionPassthrough) == 0
    assert count(SaveLayerNormAdapter) == 0
    assert count(SelfAttentionInjectionAdapter) == 0

    # Ejecting an adapter that was never injected must be rejected.
    with pytest.raises(AssertionError, match="not the first element"):
        adapter.eject()

    adapter.inject()

    # After injection: the adapter wraps the UNet. There is a single passthrough,
    # plus one SaveLayerNormAdapter and one SelfAttentionInjectionAdapter per
    # cross-attention block.
    assert unet.parent == adapter
    assert count(SelfAttentionInjectionPassthrough) == 1
    assert count(SaveLayerNormAdapter) == num_blocks
    assert count(SelfAttentionInjectionAdapter) == num_blocks

    # A second injection must be rejected.
    with pytest.raises(AssertionError, match="already injected"):
        adapter.inject()

    adapter.eject()

    # After ejection: the UNet is back to its pristine state.
    assert unet.parent is None
    assert count(SelfAttentionInjectionPassthrough) == 0
    assert count(SaveLayerNormAdapter) == 0
    assert count(SelfAttentionInjectionAdapter) == 0
|