diff --git a/tests/fluxion/layers/test_chain.py b/tests/fluxion/layers/test_chain.py index 0e6360b..d977f06 100644 --- a/tests/fluxion/layers/test_chain.py +++ b/tests/fluxion/layers/test_chain.py @@ -86,6 +86,25 @@ def test_chain_slice() -> None: assert chain[:-1](x).shape == (1, 1) +def test_chain_walk_stop_iteration() -> None: + chain = fl.Chain( + fl.Sum( + fl.Chain(fl.Linear(in_features=1, out_features=1)), + fl.Linear(in_features=1, out_features=1), + ), + fl.Chain(), + fl.Linear(in_features=1, out_features=1), + ) + + def predicate(m: fl.Module, p: fl.Chain) -> bool: + if isinstance(m, fl.Sum): + raise StopIteration + return isinstance(m, fl.Linear) + + assert len(list(chain.walk(fl.Linear))) == 3 + assert len(list(chain.walk(predicate))) == 1 + + def test_chain_layers() -> None: chain = fl.Chain( fl.Chain(fl.Chain(fl.Chain())),