diff --git a/tests/fluxion/layers/test_chain.py b/tests/fluxion/layers/test_chain.py index d977f06..c443c1d 100644 --- a/tests/fluxion/layers/test_chain.py +++ b/tests/fluxion/layers/test_chain.py @@ -27,13 +27,13 @@ def test_chain_remove_replace(): assert isinstance(chain.Chain.Chain[1], fl.Conv2d) -def test_chain_structural_copy(): +def test_chain_structural_copy() -> None: m = fl.Chain( fl.Sum( - fl.Linear(in_features=4, out_features=8), - fl.Linear(in_features=4, out_features=8), + fl.Linear(4, 8), + fl.Linear(4, 8), ), - fl.Linear(in_features=8, out_features=12), + fl.Linear(8, 12), ) x = torch.randn(7, 4) @@ -57,10 +57,8 @@ def test_chain_structural_copy(): torch.equal(y2, y) -def test_chain_find(): - chain = fl.Chain( - fl.Linear(1, 1), - ) +def test_chain_find() -> None: + chain = fl.Chain(fl.Linear(1, 1)) assert isinstance(chain.find(fl.Linear), fl.Linear) assert chain.find(fl.Conv2d) is None @@ -68,14 +66,14 @@ def test_chain_find(): def test_chain_slice() -> None: chain = fl.Chain( - fl.Linear(in_features=1, out_features=1), - fl.Linear(in_features=1, out_features=1), - fl.Linear(in_features=1, out_features=1), + fl.Linear(1, 1), + fl.Linear(1, 1), + fl.Linear(1, 1), fl.Chain( - fl.Linear(in_features=1, out_features=1), - fl.Linear(in_features=1, out_features=1), + fl.Linear(1, 1), + fl.Linear(1, 1), ), - fl.Linear(in_features=1, out_features=1), + fl.Linear(1, 1), ) x = torch.randn(1, 1) @@ -89,11 +87,11 @@ def test_chain_slice() -> None: def test_chain_walk_stop_iteration() -> None: chain = fl.Chain( fl.Sum( - fl.Chain(fl.Linear(in_features=1, out_features=1)), - fl.Linear(in_features=1, out_features=1), + fl.Chain(fl.Linear(1, 1)), + fl.Linear(1, 1), ), fl.Chain(), - fl.Linear(in_features=1, out_features=1), + fl.Linear(1, 1), ) def predicate(m: fl.Module, p: fl.Chain) -> bool: @@ -109,7 +107,7 @@ def test_chain_layers() -> None: chain = fl.Chain( fl.Chain(fl.Chain(fl.Chain())), fl.Chain(), - fl.Linear(in_features=1, out_features=1), + fl.Linear(1, 1), ) assert len(list(chain.layers(fl.Chain))) == 2