mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
add a test for StopIteration in walk
This commit is contained in:
parent
dec0d64432
commit
a0c70ba7aa
|
@ -86,6 +86,25 @@ def test_chain_slice() -> None:
|
||||||
assert chain[:-1](x).shape == (1, 1)
|
assert chain[:-1](x).shape == (1, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chain_walk_stop_iteration() -> None:
|
||||||
|
chain = fl.Chain(
|
||||||
|
fl.Sum(
|
||||||
|
fl.Chain(fl.Linear(in_features=1, out_features=1)),
|
||||||
|
fl.Linear(in_features=1, out_features=1),
|
||||||
|
),
|
||||||
|
fl.Chain(),
|
||||||
|
fl.Linear(in_features=1, out_features=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
def predicate(m: fl.Module, p: fl.Chain) -> bool:
|
||||||
|
if isinstance(m, fl.Sum):
|
||||||
|
raise StopIteration
|
||||||
|
return isinstance(m, fl.Linear)
|
||||||
|
|
||||||
|
assert len(list(chain.walk(fl.Linear))) == 3
|
||||||
|
assert len(list(chain.walk(predicate))) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_chain_layers() -> None:
|
def test_chain_layers() -> None:
|
||||||
chain = fl.Chain(
|
chain = fl.Chain(
|
||||||
fl.Chain(fl.Chain(fl.Chain())),
|
fl.Chain(fl.Chain(fl.Chain())),
|
||||||
|
|
Loading…
Reference in a new issue