diff --git a/tests/fluxion/layers/test_chain.py b/tests/fluxion/layers/test_chain.py index 402564c..5e87812 100644 --- a/tests/fluxion/layers/test_chain.py +++ b/tests/fluxion/layers/test_chain.py @@ -21,6 +21,13 @@ def test_chain_find() -> None: assert chain.find(fl.Conv2d) is None +def test_chain_getitem_accessor() -> None: + chain = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)) + assert chain["Linear_2"] == chain.Linear_2 + with pytest.raises(KeyError): + chain["Linear_3"] + + def test_chain_find_parent(): chain = fl.Chain(fl.Chain(fl.Linear(1, 1))) @@ -48,6 +55,30 @@ def test_chain_slice() -> None: assert chain[:-1](x).shape == (1, 1) +def test_chain_walk() -> None: + chain = fl.Chain( + fl.Sum( + fl.Chain(fl.Linear(1, 1)), + fl.Linear(1, 1), + ), + fl.Chain(), + ) + + assert list(chain.walk()) == [(chain.Sum, chain), (chain.Chain, chain)] + assert list(chain.walk(fl.Linear)) == [ + (chain.Sum.Chain.Linear, chain.Sum.Chain), + (chain.Sum.Linear, chain.Sum), + ] + + assert list(chain.walk(recurse=True)) == [ + (chain.Sum, chain), + (chain.Sum.Chain, chain.Sum), + (chain.Sum.Chain.Linear, chain.Sum.Chain), + (chain.Sum.Linear, chain.Sum), + (chain.Chain, chain), + ] + + def test_chain_walk_stop_iteration() -> None: chain = fl.Chain( fl.Sum(