cosmetics

This commit is contained in:
Pierre Chapuis 2023-08-23 14:56:31 +02:00
parent 16618d73de
commit 337d2aea58

View file

@ -27,13 +27,13 @@ def test_chain_remove_replace():
assert isinstance(chain.Chain.Chain[1], fl.Conv2d)
def test_chain_structural_copy():
def test_chain_structural_copy() -> None:
m = fl.Chain(
fl.Sum(
fl.Linear(in_features=4, out_features=8),
fl.Linear(in_features=4, out_features=8),
fl.Linear(4, 8),
fl.Linear(4, 8),
),
fl.Linear(in_features=8, out_features=12),
fl.Linear(8, 12),
)
x = torch.randn(7, 4)
@ -57,10 +57,8 @@ def test_chain_structural_copy():
torch.equal(y2, y)
def test_chain_find():
chain = fl.Chain(
fl.Linear(1, 1),
)
def test_chain_find() -> None:
chain = fl.Chain(fl.Linear(1, 1))
assert isinstance(chain.find(fl.Linear), fl.Linear)
assert chain.find(fl.Conv2d) is None
@ -68,14 +66,14 @@ def test_chain_find():
def test_chain_slice() -> None:
chain = fl.Chain(
fl.Linear(in_features=1, out_features=1),
fl.Linear(in_features=1, out_features=1),
fl.Linear(in_features=1, out_features=1),
fl.Linear(1, 1),
fl.Linear(1, 1),
fl.Linear(1, 1),
fl.Chain(
fl.Linear(in_features=1, out_features=1),
fl.Linear(in_features=1, out_features=1),
fl.Linear(1, 1),
fl.Linear(1, 1),
),
fl.Linear(in_features=1, out_features=1),
fl.Linear(1, 1),
)
x = torch.randn(1, 1)
@ -89,11 +87,11 @@ def test_chain_slice() -> None:
def test_chain_walk_stop_iteration() -> None:
chain = fl.Chain(
fl.Sum(
fl.Chain(fl.Linear(in_features=1, out_features=1)),
fl.Linear(in_features=1, out_features=1),
fl.Chain(fl.Linear(1, 1)),
fl.Linear(1, 1),
),
fl.Chain(),
fl.Linear(in_features=1, out_features=1),
fl.Linear(1, 1),
)
def predicate(m: fl.Module, p: fl.Chain) -> bool:
@ -109,7 +107,7 @@ def test_chain_layers() -> None:
chain = fl.Chain(
fl.Chain(fl.Chain(fl.Chain())),
fl.Chain(),
fl.Linear(in_features=1, out_features=1),
fl.Linear(1, 1),
)
assert len(list(chain.layers(fl.Chain))) == 2