write StyleAligned inject/eject tests

This commit is contained in:
Laurent 2024-02-15 14:11:04 +00:00 committed by Laureηt
parent 2a3e353f04
commit 60c0780fe7

View file

@ -0,0 +1,74 @@
import pytest
from refiners.fluxion.layers.attentions import SelfAttention
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
from refiners.foundationals.latent_diffusion.style_aligned import (
SharedSelfAttentionAdapter,
StyleAligned,
StyleAlignedAdapter,
)
@pytest.fixture(scope="module", params=[1.0, 100.0])
def scale(request: pytest.FixtureRequest) -> float:
return request.param
@pytest.fixture(scope="module", params=[SD1UNet, SDXLUNet])
def unet(request: pytest.FixtureRequest) -> SD1UNet | SDXLUNet:
return request.param(in_channels=4)
@pytest.fixture(scope="module")
def self_attention() -> SelfAttention:
return SelfAttention(embedding_dim=100)
@pytest.fixture(scope="module")
def adapter_SAA(unet: SD1UNet | SDXLUNet, scale: float) -> StyleAlignedAdapter[SD1UNet | SDXLUNet]:
return StyleAlignedAdapter(target=unet, scale=scale)
@pytest.fixture(scope="module")
def adapter_SSA(self_attention: SelfAttention, scale: float) -> SharedSelfAttentionAdapter:
return SharedSelfAttentionAdapter(target=self_attention, scale=scale)
def test_inject_eject_SharedSelfAttentionAdapter(
self_attention: SD1UNet | SDXLUNet, adapter_SSA: SharedSelfAttentionAdapter
):
initial_repr = repr(self_attention)
assert self_attention.parent is None
assert self_attention.find(StyleAligned) is None
assert repr(self_attention) == initial_repr
adapter_SSA.inject()
assert self_attention.parent is not None
assert self_attention.find(StyleAligned) is not None
assert repr(self_attention) != initial_repr
adapter_SSA.eject()
assert self_attention.parent is None
assert self_attention.find(StyleAligned) is None
assert repr(self_attention) == initial_repr
def test_inject_eject_StyleAlignedAdapter(
unet: SD1UNet | SDXLUNet, adapter_SAA: StyleAlignedAdapter[SD1UNet | SDXLUNet]
):
initial_repr = repr(unet)
assert unet.parent is None
assert unet.find(SharedSelfAttentionAdapter) is None
assert repr(unet) == initial_repr
adapter_SAA.inject()
assert unet.parent is not None
assert unet.find(SharedSelfAttentionAdapter) is not None
assert repr(unet) != initial_repr
adapter_SAA.eject()
assert unet.parent is None
assert unet.find(SharedSelfAttentionAdapter) is None
assert repr(unet) == initial_repr