refiners/tests/adapters/test_adapter.py

83 lines
2.1 KiB
Python
Raw Normal View History

2023-08-04 13:28:41 +00:00
import pytest
from refiners.adapters.adapter import Adapter
from refiners.fluxion.layers import Chain, Linear
class DummyLinearAdapter(Chain, Adapter[Linear]):
    """Minimal test adapter that wraps a single ``Linear`` layer.

    Its only child is the target layer itself; it adds no behavior beyond
    what ``Chain`` and ``Adapter`` already provide.
    """

    def __init__(self, target: Linear) -> None:
        # The Chain constructor must run inside ``setup_adapter``; presumably the
        # adapter machinery records the target before it becomes a child
        # (NOTE(review): confirm against ``Adapter.setup_adapter``).
        with self.setup_adapter(target):
            super().__init__(target)
class DummyChainAdapter(Chain, Adapter[Chain]):
    """Minimal test adapter that wraps an entire ``Chain``.

    Its only child is the target chain itself; it adds no behavior beyond
    what ``Chain`` and ``Adapter`` already provide.
    """

    def __init__(self, target: Chain) -> None:
        # Same construction protocol as the Linear adapter: build the Chain
        # inside the ``setup_adapter`` context.
        with self.setup_adapter(target):
            super().__init__(target)
@pytest.fixture
def chain() -> Chain:
    """Return an outer Chain whose single child is an inner Chain holding a Linear(2, 2)."""
    linear = Linear(2, 2)
    inner = Chain(linear)
    return Chain(inner)
def test_weighted_module_adapter_insertion(chain: Chain):
    """Inject a Linear adapter into its parent chain, then eject it.

    While injected, the adapter replaces the Linear among the parent's
    children. After ejection the Linear is back and the adapter has no parent.
    """
    inner = chain.Chain
    linear = inner.Linear
    wrapper = DummyLinearAdapter(linear)

    # Injection: the adapter takes the Linear's slot.
    wrapper.inject(inner)
    assert wrapper.parent == inner
    assert wrapper in iter(inner)
    assert linear not in iter(inner)

    # Ejection: the original layout is restored.
    wrapper.eject()
    assert wrapper.parent is None
    assert wrapper not in iter(inner)
    assert linear in iter(inner)
def test_chain_adapter_insertion(chain: Chain):
    """Inject a Chain adapter around the inner Chain, then eject it.

    ``inject()`` is called with no explicit parents. The assertions check three
    things: the adapter lands in the adaptee's original parent, the adaptee is
    re-parented under the adapter, and ejection undoes both.
    """
    root = chain
    wrapped = root.Chain
    wrapper = DummyChainAdapter(wrapped)
    assert wrapped.parent == root

    # Injection: the adapter sits between root and the wrapped chain.
    wrapper.inject()
    assert wrapper.parent == root
    assert wrapped.parent == wrapper
    assert wrapper in iter(root)
    assert wrapped not in iter(root)

    # Ejection: the wrapped chain is a direct child of root again.
    wrapper.eject()
    assert wrapper.parent is None
    assert wrapped.parent == root
    assert wrapper not in iter(root)
    assert wrapped in iter(root)
def test_weighted_module_adapter_structural_copy(chain: Chain):
    """Structurally copy a tree that contains an injected Linear adapter.

    The adapter (itself a Chain) should be duplicated into the clone and
    attached to the cloned parent. Its target should still refer to the
    original Linear, and the source tree should be left untouched.
    """
    parent = chain.Chain
    adaptee = parent.Linear
    adapter = DummyLinearAdapter(adaptee)
    adapter.inject(parent)

    clone = chain.structural_copy()
    cloned_adapter = clone.Chain.DummyLinearAdapter

    # The adapter is an inner Chain node, so the copy must be a distinct
    # object, not the original adapter re-used inside the clone.
    assert cloned_adapter is not adapter
    assert cloned_adapter.parent == clone.Chain
    # The target is not duplicated: it still points at the original Linear.
    assert cloned_adapter.target == adaptee
    # Copying must not re-parent the original adapter.
    assert adapter.parent == parent
def test_chain_adapter_structural_copy(chain: Chain):
    """A tree holding an injected Chain adapter cannot be structurally copied.

    Chain adapters cannot be copied by default. ``structural_copy`` raises
    ``RuntimeError`` while the adapter is in place and succeeds again once the
    adapter has been ejected.
    """
    inner = chain.Chain
    wrapper = DummyChainAdapter(inner)
    wrapper.inject()

    with pytest.raises(RuntimeError):
        chain.structural_copy()

    wrapper.eject()
    chain.structural_copy()  # must not raise once the adapter is gone