refiners/tests/adapters/test_style_aligned_adapter.py
2024-02-15 15:22:47 +01:00

75 lines
2.3 KiB
Python

import pytest
from refiners.fluxion.layers.attentions import SelfAttention
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
from refiners.foundationals.latent_diffusion.style_aligned import (
SharedSelfAttentionAdapter,
StyleAligned,
StyleAlignedAdapter,
)
@pytest.fixture(scope="module", params=[1.0, 100.0])
def scale(request: pytest.FixtureRequest) -> float:
    """Provide the style-aligned scale; dependent tests run once per listed value."""
    chosen: float = request.param
    return chosen
@pytest.fixture(scope="module", params=[SD1UNet, SDXLUNet])
def unet(request: pytest.FixtureRequest) -> SD1UNet | SDXLUNet:
    """Instantiate the parametrized UNet class (SD1 or SDXL) with 4 input channels."""
    unet_class = request.param
    model = unet_class(in_channels=4)
    return model
@pytest.fixture(scope="module")
def self_attention() -> SelfAttention:
    """Build a standalone SelfAttention layer (embedding_dim=100) used as an adapter target."""
    layer = SelfAttention(embedding_dim=100)
    return layer
@pytest.fixture(scope="module")
def adapter_SAA(unet: SD1UNet | SDXLUNet, scale: float) -> StyleAlignedAdapter[SD1UNet | SDXLUNet]:
    """Wrap the UNet fixture in a StyleAlignedAdapter.

    The adapter is only constructed here; tests are responsible for calling
    `inject()` / `eject()`.
    """
    adapter = StyleAlignedAdapter(scale=scale, target=unet)
    return adapter
@pytest.fixture(scope="module")
def adapter_SSA(self_attention: SelfAttention, scale: float) -> SharedSelfAttentionAdapter:
    """Wrap the SelfAttention fixture in a SharedSelfAttentionAdapter.

    The adapter is only constructed here; tests are responsible for calling
    `inject()` / `eject()`.
    """
    adapter = SharedSelfAttentionAdapter(scale=scale, target=self_attention)
    return adapter
def test_inject_eject_SharedSelfAttentionAdapter(
    self_attention: SelfAttention, adapter_SSA: SharedSelfAttentionAdapter
) -> None:
    """Injecting then ejecting a SharedSelfAttentionAdapter must round-trip its target.

    After `inject()` the SelfAttention has a parent, contains a StyleAligned
    layer and has a different repr. After `eject()` all three return to their
    initial values.
    """
    initial_repr = repr(self_attention)
    assert self_attention.parent is None
    assert self_attention.find(StyleAligned) is None
    # Sanity check: repr is stable, so the comparisons below are meaningful.
    assert repr(self_attention) == initial_repr

    adapter_SSA.inject()
    try:
        assert self_attention.parent is not None
        assert self_attention.find(StyleAligned) is not None
        assert repr(self_attention) != initial_repr
    finally:
        # Fixtures are module-scoped and reused across `scale` params: always
        # eject so a failed assertion does not leave the target injected and
        # cascade into failures of later parametrized runs.
        adapter_SSA.eject()

    assert self_attention.parent is None
    assert self_attention.find(StyleAligned) is None
    assert repr(self_attention) == initial_repr
def test_inject_eject_StyleAlignedAdapter(
    unet: SD1UNet | SDXLUNet, adapter_SAA: StyleAlignedAdapter[SD1UNet | SDXLUNet]
) -> None:
    """Injecting then ejecting a StyleAlignedAdapter must round-trip its UNet.

    After `inject()` the UNet has a parent, contains a SharedSelfAttentionAdapter
    and has a different repr. After `eject()` all three return to their initial
    values.
    """
    initial_repr = repr(unet)
    assert unet.parent is None
    assert unet.find(SharedSelfAttentionAdapter) is None
    # Sanity check: repr is stable, so the comparisons below are meaningful.
    assert repr(unet) == initial_repr

    adapter_SAA.inject()
    try:
        assert unet.parent is not None
        assert unet.find(SharedSelfAttentionAdapter) is not None
        assert repr(unet) != initial_repr
    finally:
        # The UNet fixture is module-scoped and shared across `scale` params:
        # always eject so a failed assertion does not leave it injected for
        # subsequent parametrized runs.
        adapter_SAA.eject()

    assert unet.parent is None
    assert unet.find(SharedSelfAttentionAdapter) is None
    assert repr(unet) == initial_repr