mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
add extra tests for Chain
This commit is contained in:
parent
bca50b71f2
commit
f43a530254
|
@ -21,6 +21,13 @@ def test_chain_find() -> None:
|
|||
assert chain.find(fl.Conv2d) is None
|
||||
|
||||
|
||||
def test_chain_getitem_accessor() -> None:
|
||||
chain = fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1))
|
||||
assert chain["Linear_2"] == chain.Linear_2
|
||||
with pytest.raises(KeyError):
|
||||
chain["Linear_3"]
|
||||
|
||||
|
||||
def test_chain_find_parent():
|
||||
chain = fl.Chain(fl.Chain(fl.Linear(1, 1)))
|
||||
|
||||
|
@ -48,6 +55,30 @@ def test_chain_slice() -> None:
|
|||
assert chain[:-1](x).shape == (1, 1)
|
||||
|
||||
|
||||
def test_chain_walk() -> None:
|
||||
chain = fl.Chain(
|
||||
fl.Sum(
|
||||
fl.Chain(fl.Linear(1, 1)),
|
||||
fl.Linear(1, 1),
|
||||
),
|
||||
fl.Chain(),
|
||||
)
|
||||
|
||||
assert list(chain.walk()) == [(chain.Sum, chain), (chain.Chain, chain)]
|
||||
assert list(chain.walk(fl.Linear)) == [
|
||||
(chain.Sum.Chain.Linear, chain.Sum.Chain),
|
||||
(chain.Sum.Linear, chain.Sum),
|
||||
]
|
||||
|
||||
assert list(chain.walk(recurse=True)) == [
|
||||
(chain.Sum, chain),
|
||||
(chain.Sum.Chain, chain.Sum),
|
||||
(chain.Sum.Chain.Linear, chain.Sum.Chain),
|
||||
(chain.Sum.Linear, chain.Sum),
|
||||
(chain.Chain, chain),
|
||||
]
|
||||
|
||||
|
||||
def test_chain_walk_stop_iteration() -> None:
|
||||
chain = fl.Chain(
|
||||
fl.Sum(
|
||||
|
|
Loading…
Reference in a new issue