mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 14:18:46 +00:00
81 lines
2 KiB
Python
import pytest
|
|
|
|
from refiners.fluxion.adapters.adapter import Adapter
|
|
from refiners.fluxion.layers import Chain, Linear
|
|
|
|
|
|
class DummyLinearAdapter(Chain, Adapter[Linear]):
    """Do-nothing adapter whose target is a single `Linear` layer.

    It only wraps its target in a `Chain`, which is all the tests below need
    to exercise injection, ejection and structural copying of an adapter
    around a weighted module.
    """

    def __init__(self, target: Linear) -> None:
        wrapped = target
        # Build the Chain inside `setup_adapter` so the adapter is bound to
        # `wrapped` (later exposed as `.target`).
        with self.setup_adapter(wrapped):
            super().__init__(wrapped)
|
|
|
|
|
|
class DummyChainAdapter(Chain, Adapter[Chain]):
    """Do-nothing adapter whose target is itself a `Chain`.

    Used to check adapter re-parenting for Chain targets, and that a tree
    containing such an adapter refuses `structural_copy` (RuntimeError)
    until the adapter is ejected.
    """

    def __init__(self, target: Chain) -> None:
        wrapped = target
        # Build the Chain inside `setup_adapter` so the adapter is bound to
        # `wrapped` (later exposed as `.target`).
        with self.setup_adapter(wrapped):
            super().__init__(wrapped)
|
|
|
|
|
|
@pytest.fixture
def chain() -> Chain:
    """Return a fresh two-level tree: outer `Chain` -> inner `Chain` -> `Linear(2, 2)`.

    The fixture uses pytest's default function scope, so each test mutates its own tree.
    """
    inner = Chain(Linear(2, 2))
    return Chain(inner)
|
|
|
|
|
|
def test_weighted_module_adapter_insertion(chain: Chain):
    """Injecting an adapter around a `Linear` swaps it into the parent; ejecting swaps it back."""
    parent = chain.Chain
    adaptee = parent.Linear

    adapter = DummyLinearAdapter(adaptee).inject(parent)

    # Identity (`is`) rather than equality: these must be the very same module objects.
    assert adapter.target is adaptee
    assert adapter.parent is parent
    assert adapter in iter(parent)
    assert adaptee not in iter(parent)

    adapter.eject()
    assert adapter.parent is None
    assert adapter not in iter(parent)
    assert adaptee in iter(parent)
|
|
|
|
|
|
def test_chain_adapter_insertion(chain: Chain):
    """Injecting a Chain adapter re-parents its target under it; ejecting restores the tree."""
    parent = chain
    adaptee = parent.Chain

    adapter = DummyChainAdapter(adaptee)
    # Merely constructing the adapter must not detach the target from its parent.
    assert adaptee.parent is parent
    assert adapter.target is adaptee

    # No explicit parent given: inject into the adaptee's current parent.
    adapter.inject()
    assert adapter.parent is parent
    assert adaptee.parent is adapter
    assert adapter in iter(parent)
    assert adaptee not in iter(parent)

    adapter.eject()
    assert adapter.parent is None
    assert adaptee.parent is parent
    assert adapter not in iter(parent)
    assert adaptee in iter(parent)
|
|
|
|
|
|
def test_weighted_module_adapter_structural_copy(chain: Chain):
    """A structural copy duplicates an injected `Linear` adapter but shares its weighted target."""
    parent = chain.Chain
    adaptee = parent.Linear

    adapter = DummyLinearAdapter(adaptee).inject(parent)

    clone = chain.structural_copy()
    cloned_adapter = clone.Chain.DummyLinearAdapter

    # The clone holds its own adapter instance, attached to the cloned tree...
    assert cloned_adapter is not adapter
    assert cloned_adapter.parent is clone.Chain
    # ...while the original adapter stays attached to the original tree.
    assert adapter.parent is parent
    # The weighted `Linear` target is shared between original and clone, not duplicated.
    assert cloned_adapter.target is adaptee
|
|
|
|
|
|
def test_chain_adapter_structural_copy(chain: Chain):
    """`structural_copy` is refused while a Chain adapter is injected and allowed after ejection."""
    inner = chain.Chain
    # Chain adapters cannot be copied by default.
    chain_adapter = DummyChainAdapter(inner).inject()

    with pytest.raises(RuntimeError):
        chain.structural_copy()

    # Once the adapter is removed, copying the tree succeeds again.
    chain_adapter.eject()
    chain.structural_copy()
|