2023-08-04 13:28:41 +00:00
|
|
|
import pytest
|
2023-12-11 10:46:38 +00:00
|
|
|
|
2023-09-01 14:50:41 +00:00
|
|
|
from refiners.fluxion.adapters.adapter import Adapter
|
2023-08-04 13:28:41 +00:00
|
|
|
from refiners.fluxion.layers import Chain, Linear
|
|
|
|
|
|
|
|
|
|
|
|
class DummyLinearAdapter(Chain, Adapter[Linear]):
    """Test-only adapter: a `Chain` whose single child is the `Linear` layer it adapts."""

    def __init__(self, target: Linear) -> None:
        # Wrap `target` as this chain's only child, inside the `setup_adapter` context.
        adapter_context = self.setup_adapter(target)
        with adapter_context:
            super().__init__(target)
|
|
|
|
|
|
|
|
|
|
|
|
class DummyChainAdapter(Chain, Adapter[Chain]):
    """Test-only adapter: a `Chain` whose single child is the `Chain` it adapts."""

    def __init__(self, target: Chain) -> None:
        # Wrap `target` as this chain's only child, inside the `setup_adapter` context.
        adapter_context = self.setup_adapter(target)
        with adapter_context:
            super().__init__(target)
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def chain() -> Chain:
    """Return an outer `Chain` holding one inner `Chain`, which itself holds a `Linear(2, 2)`."""
    inner = Chain(Linear(2, 2))
    return Chain(inner)
|
|
|
|
|
|
|
|
|
|
|
|
def test_weighted_module_adapter_insertion(chain: Chain):
    """Injecting a `Linear` adapter replaces the layer in its parent; ejecting restores it."""
    parent = chain.layer("Chain", Chain)
    adaptee = parent.layer("Linear", Linear)

    adapter = DummyLinearAdapter(adaptee).inject(parent)

    # Parent links are object identities, so compare with `is` rather than `==`.
    assert adapter.parent is parent
    assert adapter in iter(parent)
    assert adaptee not in iter(parent)

    adapter.eject()

    assert adapter.parent is None
    assert adapter not in iter(parent)
    assert adaptee in iter(parent)
|
|
|
|
|
|
|
|
|
|
|
|
def test_chain_adapter_insertion(chain: Chain):
    """Injecting a `Chain` adapter re-parents the adapted chain under it; ejecting undoes that."""
    parent = chain
    adaptee = parent.layer("Chain", Chain)

    adapter = DummyChainAdapter(adaptee)
    # Merely constructing the adapter must not detach the adaptee from its parent.
    assert adaptee.parent is parent

    # No explicit parent is passed: the adapter is injected where its target currently lives.
    adapter.inject()

    # Parent links are object identities, so compare with `is` rather than `==`.
    assert adapter.parent is parent
    assert adaptee.parent is adapter
    assert adapter in iter(parent)
    assert adaptee not in iter(parent)

    adapter.eject()

    assert adapter.parent is None
    assert adaptee.parent is parent
    assert adapter not in iter(parent)
    assert adaptee in iter(parent)
|
|
|
|
|
|
|
|
|
|
|
|
def test_weighted_module_adapter_structural_copy(chain: Chain):
    """A structural copy clones an injected `Linear` adapter but keeps the original `Linear` as its target."""
    parent = chain.layer("Chain", Chain)
    adaptee = parent.layer("Linear", Linear)
    adapter = DummyLinearAdapter(adaptee).inject(parent)

    clone = chain.structural_copy()
    cloned_adapter = clone.layer(("Chain", "DummyLinearAdapter"), DummyLinearAdapter)

    # The adapter is part of the copied tree: a new object attached to the copied inner chain.
    assert cloned_adapter is not adapter
    assert cloned_adapter.parent is clone.Chain
    # Its target, however, is the very same `Linear` instance, not a copy.
    assert cloned_adapter.target is adaptee
|
|
|
|
|
|
|
|
|
|
|
|
def test_chain_adapter_structural_copy(chain: Chain):
    """Structural copy fails while a `Chain` adapter is injected and succeeds again once it is ejected."""
    adaptee = chain.layer("Chain", Chain)
    # Chain adapters cannot be copied by default.
    adapter = DummyChainAdapter(adaptee).inject()

    with pytest.raises(RuntimeError):
        chain.structural_copy()

    adapter.eject()

    # Previously the result was discarded; check that the copy is real and has the restored structure.
    clone = chain.structural_copy()
    assert clone is not chain
    # With the adapter ejected, the inner chain is reachable under its own name again, as a distinct copy.
    assert clone.layer("Chain", Chain) is not adaptee
|