diff --git a/tests/fluxion/layers/test_chain.py b/tests/fluxion/layers/test_chain.py index 5aab79b..fea31ab 100644 --- a/tests/fluxion/layers/test_chain.py +++ b/tests/fluxion/layers/test_chain.py @@ -1,14 +1,32 @@ +import pytest import torch import refiners.fluxion.layers as fl +from refiners.fluxion.context import Contexts + + +class ContextChain(fl.Chain): + def init_context(self) -> Contexts: + return {"foo": {"bar": [42]}} + + +def module_keys(chain: fl.Chain) -> list[str]: + return list(chain._modules.keys()) # type: ignore[reportPrivateUsage] def test_chain_find() -> None: chain = fl.Chain(fl.Linear(1, 1)) - assert isinstance(chain.find(fl.Linear), fl.Linear) + assert chain.find(fl.Linear) == chain.Linear assert chain.find(fl.Conv2d) is None +def test_chain_find_parent(): + chain = fl.Chain(fl.Chain(fl.Linear(1, 1))) + + assert chain.find_parent(chain.Chain.Linear) == chain.Chain + assert chain.find_parent(fl.Linear(1, 1)) is None + + def test_chain_slice() -> None: chain = fl.Chain( fl.Linear(1, 1), @@ -59,21 +77,91 @@ def test_chain_layers() -> None: assert len(list(chain.layers(fl.Chain, recurse=True))) == 4 +def test_chain_insert() -> None: + parent = ContextChain(fl.Linear(1, 1), fl.Linear(1, 1)) + child = fl.Chain() + parent.insert(1, child) + + assert module_keys(parent) == ["Linear_1", "Chain", "Linear_2"] + assert child.parent == parent + assert child.provider.get_context("foo") == {"bar": [42]} + + +def test_chain_insert_negative() -> None: + parent = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)) + child = fl.Chain() + parent.insert(-2, child) + + assert module_keys(parent) == ["Linear_1", "Chain", "Linear_2"] + + +def test_chain_insert_after_type() -> None: + child = fl.Chain() + + parent_1 = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)) + parent_1.insert_after_type(fl.Linear, child) + assert module_keys(parent_1) == ["Linear_1", "Chain", "Linear_2"] + + parent_2 = fl.Chain(fl.Conv2d(1, 1, 1), fl.Linear(1, 1)) + parent_2.insert_after_type(fl.Linear, child) + assert module_keys(parent_2) == ["Conv2d", "Linear", "Chain"] + + +def test_chain_insert_overflow() -> None: + # This behaves as insert() in lists in Python. + + child = fl.Chain() + + parent_1 = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)) + parent_1.insert(42, child) + assert module_keys(parent_1) == ["Linear_1", "Linear_2", "Chain"] + + parent_2 = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)) + parent_2.insert(-42, child) + assert module_keys(parent_2) == ["Chain", "Linear_1", "Linear_2"] + + +def test_chain_append() -> None: + child = fl.Chain() + + parent = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)) + parent.append(child) + assert module_keys(parent) == ["Linear_1", "Linear_2", "Chain"] + + +def test_chain_pop() -> None: + chain = fl.Chain(fl.Linear(1, 1), fl.Conv2d(1, 1, 1), fl.Chain()) + + with pytest.raises(IndexError): + chain.pop(3) + + with pytest.raises(IndexError): + chain.pop(-4) + + assert module_keys(chain) == ["Linear", "Conv2d", "Chain"] + chain.pop(1) + assert module_keys(chain) == ["Linear", "Chain"] + + chain.pop(-2) + assert module_keys(chain) == ["Chain"] + + def test_chain_remove() -> None: - chain = fl.Chain( - fl.Linear(1, 1), + child = fl.Linear(1, 1) + + parent = fl.Chain( fl.Linear(1, 1), + child, fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)), ) - assert len(chain) == 3 - assert "Linear_1" in chain._modules - assert "Linear" not in chain._modules + assert child in parent + assert module_keys(parent) == ["Linear_1", "Linear_2", "Chain"] - chain.remove(chain.Linear_2) - assert len(chain) == 2 - assert "Linear" in chain._modules - assert "Linear_1" not in chain._modules + parent.remove(child) + + assert child not in parent + assert module_keys(parent) == ["Linear", "Chain"] def test_chain_replace() -> None: