From 60c0780fe75988055b689aa6b6657e5c520c3ee0 Mon Sep 17 00:00:00 2001
From: Laurent
Date: Thu, 15 Feb 2024 14:11:04 +0000
Subject: [PATCH] write `StyleAligned` inject/eject tests

---
 tests/adapters/test_style_aligned_adapter.py | 74 ++++++++++++++++++++
 1 file changed, 74 insertions(+)
 create mode 100644 tests/adapters/test_style_aligned_adapter.py

diff --git a/tests/adapters/test_style_aligned_adapter.py b/tests/adapters/test_style_aligned_adapter.py
new file mode 100644
index 0000000..3ebafc2
--- /dev/null
+++ b/tests/adapters/test_style_aligned_adapter.py
@@ -0,0 +1,74 @@
+import pytest
+
+from refiners.fluxion.layers.attentions import SelfAttention
+from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
+from refiners.foundationals.latent_diffusion.style_aligned import (
+    SharedSelfAttentionAdapter,
+    StyleAligned,
+    StyleAlignedAdapter,
+)
+
+
+@pytest.fixture(scope="module", params=[1.0, 100.0])
+def scale(request: pytest.FixtureRequest) -> float:
+    return request.param
+
+
+@pytest.fixture(scope="module", params=[SD1UNet, SDXLUNet])
+def unet(request: pytest.FixtureRequest) -> SD1UNet | SDXLUNet:
+    return request.param(in_channels=4)
+
+
+@pytest.fixture(scope="module")
+def self_attention() -> SelfAttention:
+    return SelfAttention(embedding_dim=100)
+
+
+@pytest.fixture(scope="module")
+def adapter_SAA(unet: SD1UNet | SDXLUNet, scale: float) -> StyleAlignedAdapter[SD1UNet | SDXLUNet]:
+    return StyleAlignedAdapter(target=unet, scale=scale)
+
+
+@pytest.fixture(scope="module")
+def adapter_SSA(self_attention: SelfAttention, scale: float) -> SharedSelfAttentionAdapter:
+    return SharedSelfAttentionAdapter(target=self_attention, scale=scale)
+
+
+def test_inject_eject_SharedSelfAttentionAdapter(
+    self_attention: SelfAttention, adapter_SSA: SharedSelfAttentionAdapter
+):
+    initial_repr = repr(self_attention)
+
+    assert self_attention.parent is None
+    assert self_attention.find(StyleAligned) is None
+    assert repr(self_attention) == initial_repr
+
+    adapter_SSA.inject()
+    assert self_attention.parent is not None
+    assert self_attention.find(StyleAligned) is not None
+    assert repr(self_attention) != initial_repr
+
+    adapter_SSA.eject()
+    assert self_attention.parent is None
+    assert self_attention.find(StyleAligned) is None
+    assert repr(self_attention) == initial_repr
+
+
+def test_inject_eject_StyleAlignedAdapter(
+    unet: SD1UNet | SDXLUNet, adapter_SAA: StyleAlignedAdapter[SD1UNet | SDXLUNet]
+):
+    initial_repr = repr(unet)
+
+    assert unet.parent is None
+    assert unet.find(SharedSelfAttentionAdapter) is None
+    assert repr(unet) == initial_repr
+
+    adapter_SAA.inject()
+    assert unet.parent is not None
+    assert unet.find(SharedSelfAttentionAdapter) is not None
+    assert repr(unet) != initial_repr
+
+    adapter_SAA.eject()
+    assert unet.parent is None
+    assert unet.find(SharedSelfAttentionAdapter) is None
+    assert repr(unet) == initial_repr