mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
add extra tests for Chain
This commit is contained in:
parent
bca50b71f2
commit
f43a530254
|
@ -21,6 +21,13 @@ def test_chain_find() -> None:
|
||||||
assert chain.find(fl.Conv2d) is None
|
assert chain.find(fl.Conv2d) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_chain_getitem_accessor() -> None:
|
||||||
|
chain = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
|
||||||
|
assert chain["Linear_2"] == chain.Linear_2
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
chain["Linear_3"]
|
||||||
|
|
||||||
|
|
||||||
def test_chain_find_parent():
|
def test_chain_find_parent():
|
||||||
chain = fl.Chain(fl.Chain(fl.Linear(1, 1)))
|
chain = fl.Chain(fl.Chain(fl.Linear(1, 1)))
|
||||||
|
|
||||||
|
@ -48,6 +55,30 @@ def test_chain_slice() -> None:
|
||||||
assert chain[:-1](x).shape == (1, 1)
|
assert chain[:-1](x).shape == (1, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chain_walk() -> None:
|
||||||
|
chain = fl.Chain(
|
||||||
|
fl.Sum(
|
||||||
|
fl.Chain(fl.Linear(1, 1)),
|
||||||
|
fl.Linear(1, 1),
|
||||||
|
),
|
||||||
|
fl.Chain(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert list(chain.walk()) == [(chain.Sum, chain), (chain.Chain, chain)]
|
||||||
|
assert list(chain.walk(fl.Linear)) == [
|
||||||
|
(chain.Sum.Chain.Linear, chain.Sum.Chain),
|
||||||
|
(chain.Sum.Linear, chain.Sum),
|
||||||
|
]
|
||||||
|
|
||||||
|
assert list(chain.walk(recurse=True)) == [
|
||||||
|
(chain.Sum, chain),
|
||||||
|
(chain.Sum.Chain, chain.Sum),
|
||||||
|
(chain.Sum.Chain.Linear, chain.Sum.Chain),
|
||||||
|
(chain.Sum.Linear, chain.Sum),
|
||||||
|
(chain.Chain, chain),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_chain_walk_stop_iteration() -> None:
|
def test_chain_walk_stop_iteration() -> None:
|
||||||
chain = fl.Chain(
|
chain = fl.Chain(
|
||||||
fl.Sum(
|
fl.Sum(
|
||||||
|
|
Loading…
Reference in a new issue