add extra tests for Chain

This commit is contained in:
Pierre Chapuis 2024-01-29 17:56:25 +01:00 committed by Cédric Deltheil
parent bca50b71f2
commit f43a530254

View file

@ -21,6 +21,13 @@ def test_chain_find() -> None:
assert chain.find(fl.Conv2d) is None assert chain.find(fl.Conv2d) is None
def test_chain_getitem_accessor() -> None:
chain = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
assert chain["Linear_2"] == chain.Linear_2
with pytest.raises(KeyError):
chain["Linear_3"]
def test_chain_find_parent(): def test_chain_find_parent():
chain = fl.Chain(fl.Chain(fl.Linear(1, 1))) chain = fl.Chain(fl.Chain(fl.Linear(1, 1)))
@ -48,6 +55,30 @@ def test_chain_slice() -> None:
assert chain[:-1](x).shape == (1, 1) assert chain[:-1](x).shape == (1, 1)
def test_chain_walk() -> None:
chain = fl.Chain(
fl.Sum(
fl.Chain(fl.Linear(1, 1)),
fl.Linear(1, 1),
),
fl.Chain(),
)
assert list(chain.walk()) == [(chain.Sum, chain), (chain.Chain, chain)]
assert list(chain.walk(fl.Linear)) == [
(chain.Sum.Chain.Linear, chain.Sum.Chain),
(chain.Sum.Linear, chain.Sum),
]
assert list(chain.walk(recurse=True)) == [
(chain.Sum, chain),
(chain.Sum.Chain, chain.Sum),
(chain.Sum.Chain.Linear, chain.Sum.Chain),
(chain.Sum.Linear, chain.Sum),
(chain.Chain, chain),
]
def test_chain_walk_stop_iteration() -> None: def test_chain_walk_stop_iteration() -> None:
chain = fl.Chain( chain = fl.Chain(
fl.Sum( fl.Sum(