mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-09 23:12:02 +00:00)
write StyleAligned inject/eject tests
This commit is contained in:
parent 2a3e353f04
commit 60c0780fe7
74  tests/adapters/test_style_aligned_adapter.py  Normal file
@@ -0,0 +1,74 @@
import pytest

from refiners.fluxion.layers.attentions import SelfAttention
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
from refiners.foundationals.latent_diffusion.style_aligned import (
    SharedSelfAttentionAdapter,
    StyleAligned,
    StyleAlignedAdapter,
)


@pytest.fixture(scope="module", params=[1.0, 100.0])
def scale(request: pytest.FixtureRequest) -> float:
    return request.param


@pytest.fixture(scope="module", params=[SD1UNet, SDXLUNet])
def unet(request: pytest.FixtureRequest) -> SD1UNet | SDXLUNet:
    return request.param(in_channels=4)


@pytest.fixture(scope="module")
def self_attention() -> SelfAttention:
    return SelfAttention(embedding_dim=100)


@pytest.fixture(scope="module")
def adapter_SAA(unet: SD1UNet | SDXLUNet, scale: float) -> StyleAlignedAdapter[SD1UNet | SDXLUNet]:
    return StyleAlignedAdapter(target=unet, scale=scale)


@pytest.fixture(scope="module")
def adapter_SSA(self_attention: SelfAttention, scale: float) -> SharedSelfAttentionAdapter:
    return SharedSelfAttentionAdapter(target=self_attention, scale=scale)


def test_inject_eject_SharedSelfAttentionAdapter(
    self_attention: SelfAttention, adapter_SSA: SharedSelfAttentionAdapter
):
    initial_repr = repr(self_attention)

    # before injection: the target is standalone and contains no StyleAligned layer
    assert self_attention.parent is None
    assert self_attention.find(StyleAligned) is None
    assert repr(self_attention) == initial_repr

    adapter_SSA.inject()
    assert self_attention.parent is not None
    assert self_attention.find(StyleAligned) is not None
    assert repr(self_attention) != initial_repr

    adapter_SSA.eject()
    assert self_attention.parent is None
    assert self_attention.find(StyleAligned) is None
    assert repr(self_attention) == initial_repr


def test_inject_eject_StyleAlignedAdapter(
    unet: SD1UNet | SDXLUNet, adapter_SAA: StyleAlignedAdapter[SD1UNet | SDXLUNet]
):
    initial_repr = repr(unet)

    # before injection: the UNet is standalone and none of its self-attentions are adapted
    assert unet.parent is None
    assert unet.find(SharedSelfAttentionAdapter) is None
    assert repr(unet) == initial_repr

    adapter_SAA.inject()
    assert unet.parent is not None
    assert unet.find(SharedSelfAttentionAdapter) is not None
    assert repr(unet) != initial_repr

    adapter_SAA.eject()
    assert unet.parent is None
    assert unet.find(SharedSelfAttentionAdapter) is None
    assert repr(unet) == initial_repr
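For reference, the lifecycle these tests pin down is the generic refiners adapter pattern: inject() splices the adapter into the target's parent chain, and eject() restores the original module tree. Below is a minimal usage sketch built only from the calls the tests exercise; the scale value is an arbitrary example, and the actual style-aligned generation step is elided.

from refiners.foundationals.latent_diffusion import SD1UNet
from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAdapter

unet = SD1UNet(in_channels=4)
adapter = StyleAlignedAdapter(target=unet, scale=1.0)  # scale=1.0 is an arbitrary example value

adapter.inject()  # unet gains a parent; its self-attentions are wrapped in SharedSelfAttentionAdapter
# ... run style-aligned generation with the adapted unet here ...
adapter.eject()   # unet is detached and restored; repr(unet) matches its pre-injection form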