add a test for StopIteration in walk

This commit is contained in:
Pierre Chapuis 2023-08-23 10:27:32 +02:00
parent dec0d64432
commit a0c70ba7aa

View file

@ -86,6 +86,25 @@ def test_chain_slice() -> None:
assert chain[:-1](x).shape == (1, 1)
def test_chain_walk_stop_iteration() -> None:
chain = fl.Chain(
fl.Sum(
fl.Chain(fl.Linear(in_features=1, out_features=1)),
fl.Linear(in_features=1, out_features=1),
),
fl.Chain(),
fl.Linear(in_features=1, out_features=1),
)
def predicate(m: fl.Module, p: fl.Chain) -> bool:
if isinstance(m, fl.Sum):
raise StopIteration
return isinstance(m, fl.Linear)
assert len(list(chain.walk(fl.Linear))) == 3
assert len(list(chain.walk(predicate))) == 1
def test_chain_layers() -> None:
chain = fl.Chain(
fl.Chain(fl.Chain(fl.Chain())),