mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-25 07:38:45 +00:00
75 lines
2.3 KiB
Python
75 lines
2.3 KiB
Python
import pytest
|
|
|
|
from refiners.fluxion.layers.attentions import SelfAttention
|
|
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
|
|
from refiners.foundationals.latent_diffusion.style_aligned import (
|
|
SharedSelfAttentionAdapter,
|
|
StyleAligned,
|
|
StyleAlignedAdapter,
|
|
)
|
|
|
|
|
|
@pytest.fixture(scope="module", params=[1.0, 100.0])
def scale(request: pytest.FixtureRequest) -> float:
    """Provide each parametrized ``scale`` value (1.0, then 100.0) to dependent fixtures."""
    current_scale: float = request.param
    return current_scale
|
|
|
|
|
|
@pytest.fixture(scope="module", params=[SD1UNet, SDXLUNet])
def unet(request: pytest.FixtureRequest) -> SD1UNet | SDXLUNet:
    """Instantiate the parametrized UNet class (``SD1UNet`` or ``SDXLUNet``) with ``in_channels=4``."""
    unet_cls = request.param
    model = unet_cls(in_channels=4)
    return model
|
|
|
|
|
|
@pytest.fixture(scope="module")
def self_attention() -> SelfAttention:
    """Build the module-wide ``SelfAttention`` layer (``embedding_dim=100``) targeted by the adapter tests."""
    attention_layer = SelfAttention(embedding_dim=100)
    return attention_layer
|
|
|
|
|
|
@pytest.fixture(scope="module")
def adapter_SAA(unet: SD1UNet | SDXLUNet, scale: float) -> StyleAlignedAdapter[SD1UNet | SDXLUNet]:
    """Create a ``StyleAlignedAdapter`` targeting the ``unet`` fixture with the current ``scale``."""
    style_aligned_adapter = StyleAlignedAdapter(target=unet, scale=scale)
    return style_aligned_adapter
|
|
|
|
|
|
@pytest.fixture(scope="module")
def adapter_SSA(self_attention: SelfAttention, scale: float) -> SharedSelfAttentionAdapter:
    """Create a ``SharedSelfAttentionAdapter`` targeting the ``self_attention`` fixture with the current ``scale``."""
    shared_attention_adapter = SharedSelfAttentionAdapter(target=self_attention, scale=scale)
    return shared_attention_adapter
|
|
|
|
|
|
def test_inject_eject_SharedSelfAttentionAdapter(
    self_attention: SelfAttention, adapter_SSA: SharedSelfAttentionAdapter
) -> None:
    """Check that injecting then ejecting ``adapter_SSA`` round-trips ``self_attention``.

    The layer is inspected at three points -- before injection, while injected,
    and after ejection -- through its ``parent``, ``find(StyleAligned)`` and
    ``repr``. After ejection all three must match the pre-injection state.
    """
    initial_repr = repr(self_attention)

    # Before injection: no parent, no StyleAligned layer.
    assert self_attention.parent is None
    assert self_attention.find(StyleAligned) is None
    assert repr(self_attention) == initial_repr

    adapter_SSA.inject()
    try:
        # While injected: the layer has a parent and contains a StyleAligned layer.
        assert self_attention.parent is not None
        assert self_attention.find(StyleAligned) is not None
        assert repr(self_attention) != initial_repr
    finally:
        # ``self_attention`` is a module-scoped fixture reused for every ``scale``
        # value; always eject so a failed assertion here cannot leave it
        # modified for the following runs.
        adapter_SSA.eject()

    # After ejection: back to the initial state.
    assert self_attention.parent is None
    assert self_attention.find(StyleAligned) is None
    assert repr(self_attention) == initial_repr
|
|
|
|
|
|
def test_inject_eject_StyleAlignedAdapter(
    unet: SD1UNet | SDXLUNet, adapter_SAA: StyleAlignedAdapter[SD1UNet | SDXLUNet]
) -> None:
    """Check that injecting then ejecting ``adapter_SAA`` round-trips ``unet``.

    The UNet is inspected at three points -- before injection, while injected,
    and after ejection -- through its ``parent``,
    ``find(SharedSelfAttentionAdapter)`` and ``repr``. After ejection all three
    must match the pre-injection state.
    """
    initial_repr = repr(unet)

    # Before injection: no parent, no SharedSelfAttentionAdapter inside the UNet.
    assert unet.parent is None
    assert unet.find(SharedSelfAttentionAdapter) is None
    assert repr(unet) == initial_repr

    adapter_SAA.inject()
    try:
        # While injected: the UNet has a parent and contains a SharedSelfAttentionAdapter.
        assert unet.parent is not None
        assert unet.find(SharedSelfAttentionAdapter) is not None
        assert repr(unet) != initial_repr
    finally:
        # ``unet`` is a module-scoped fixture shared by the adapters built for
        # each ``scale`` value; always eject so a failed assertion here cannot
        # leave it modified for the following runs.
        adapter_SAA.eject()

    # After ejection: back to the initial state.
    assert unet.parent is None
    assert unet.find(SharedSelfAttentionAdapter) is None
    assert repr(unet) == initial_repr
|