cosmetics

This commit is contained in:
Pierre Chapuis 2023-08-23 14:56:31 +02:00
parent 16618d73de
commit 337d2aea58

View file

@ -27,13 +27,13 @@ def test_chain_remove_replace():
assert isinstance(chain.Chain.Chain[1], fl.Conv2d) assert isinstance(chain.Chain.Chain[1], fl.Conv2d)
def test_chain_structural_copy(): def test_chain_structural_copy() -> None:
m = fl.Chain( m = fl.Chain(
fl.Sum( fl.Sum(
fl.Linear(in_features=4, out_features=8), fl.Linear(4, 8),
fl.Linear(in_features=4, out_features=8), fl.Linear(4, 8),
), ),
fl.Linear(in_features=8, out_features=12), fl.Linear(8, 12),
) )
x = torch.randn(7, 4) x = torch.randn(7, 4)
@ -57,10 +57,8 @@ def test_chain_structural_copy():
torch.equal(y2, y) torch.equal(y2, y)
def test_chain_find(): def test_chain_find() -> None:
chain = fl.Chain( chain = fl.Chain(fl.Linear(1, 1))
fl.Linear(1, 1),
)
assert isinstance(chain.find(fl.Linear), fl.Linear) assert isinstance(chain.find(fl.Linear), fl.Linear)
assert chain.find(fl.Conv2d) is None assert chain.find(fl.Conv2d) is None
@ -68,14 +66,14 @@ def test_chain_find():
def test_chain_slice() -> None: def test_chain_slice() -> None:
chain = fl.Chain( chain = fl.Chain(
fl.Linear(in_features=1, out_features=1), fl.Linear(1, 1),
fl.Linear(in_features=1, out_features=1), fl.Linear(1, 1),
fl.Linear(in_features=1, out_features=1), fl.Linear(1, 1),
fl.Chain( fl.Chain(
fl.Linear(in_features=1, out_features=1), fl.Linear(1, 1),
fl.Linear(in_features=1, out_features=1), fl.Linear(1, 1),
), ),
fl.Linear(in_features=1, out_features=1), fl.Linear(1, 1),
) )
x = torch.randn(1, 1) x = torch.randn(1, 1)
@ -89,11 +87,11 @@ def test_chain_slice() -> None:
def test_chain_walk_stop_iteration() -> None: def test_chain_walk_stop_iteration() -> None:
chain = fl.Chain( chain = fl.Chain(
fl.Sum( fl.Sum(
fl.Chain(fl.Linear(in_features=1, out_features=1)), fl.Chain(fl.Linear(1, 1)),
fl.Linear(in_features=1, out_features=1), fl.Linear(1, 1),
), ),
fl.Chain(), fl.Chain(),
fl.Linear(in_features=1, out_features=1), fl.Linear(1, 1),
) )
def predicate(m: fl.Module, p: fl.Chain) -> bool: def predicate(m: fl.Module, p: fl.Chain) -> bool:
@ -109,7 +107,7 @@ def test_chain_layers() -> None:
chain = fl.Chain( chain = fl.Chain(
fl.Chain(fl.Chain(fl.Chain())), fl.Chain(fl.Chain(fl.Chain())),
fl.Chain(), fl.Chain(),
fl.Linear(in_features=1, out_features=1), fl.Linear(1, 1),
) )
assert len(list(chain.layers(fl.Chain))) == 2 assert len(list(chain.layers(fl.Chain))) == 2