mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-12 16:18:22 +00:00
cosmetics
This commit is contained in:
parent
16618d73de
commit
337d2aea58
|
@ -27,13 +27,13 @@ def test_chain_remove_replace():
|
|||
assert isinstance(chain.Chain.Chain[1], fl.Conv2d)
|
||||
|
||||
|
||||
def test_chain_structural_copy():
|
||||
def test_chain_structural_copy() -> None:
|
||||
m = fl.Chain(
|
||||
fl.Sum(
|
||||
fl.Linear(in_features=4, out_features=8),
|
||||
fl.Linear(in_features=4, out_features=8),
|
||||
fl.Linear(4, 8),
|
||||
fl.Linear(4, 8),
|
||||
),
|
||||
fl.Linear(in_features=8, out_features=12),
|
||||
fl.Linear(8, 12),
|
||||
)
|
||||
|
||||
x = torch.randn(7, 4)
|
||||
|
@ -57,10 +57,8 @@ def test_chain_structural_copy():
|
|||
torch.equal(y2, y)
|
||||
|
||||
|
||||
def test_chain_find():
|
||||
chain = fl.Chain(
|
||||
fl.Linear(1, 1),
|
||||
)
|
||||
def test_chain_find() -> None:
|
||||
chain = fl.Chain(fl.Linear(1, 1))
|
||||
|
||||
assert isinstance(chain.find(fl.Linear), fl.Linear)
|
||||
assert chain.find(fl.Conv2d) is None
|
||||
|
@ -68,14 +66,14 @@ def test_chain_find():
|
|||
|
||||
def test_chain_slice() -> None:
|
||||
chain = fl.Chain(
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
fl.Linear(1, 1),
|
||||
fl.Linear(1, 1),
|
||||
fl.Linear(1, 1),
|
||||
fl.Chain(
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
fl.Linear(1, 1),
|
||||
fl.Linear(1, 1),
|
||||
),
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
fl.Linear(1, 1),
|
||||
)
|
||||
|
||||
x = torch.randn(1, 1)
|
||||
|
@ -89,11 +87,11 @@ def test_chain_slice() -> None:
|
|||
def test_chain_walk_stop_iteration() -> None:
|
||||
chain = fl.Chain(
|
||||
fl.Sum(
|
||||
fl.Chain(fl.Linear(in_features=1, out_features=1)),
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
fl.Chain(fl.Linear(1, 1)),
|
||||
fl.Linear(1, 1),
|
||||
),
|
||||
fl.Chain(),
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
fl.Linear(1, 1),
|
||||
)
|
||||
|
||||
def predicate(m: fl.Module, p: fl.Chain) -> bool:
|
||||
|
@ -109,7 +107,7 @@ def test_chain_layers() -> None:
|
|||
chain = fl.Chain(
|
||||
fl.Chain(fl.Chain(fl.Chain())),
|
||||
fl.Chain(),
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
fl.Linear(1, 1),
|
||||
)
|
||||
|
||||
assert len(list(chain.layers(fl.Chain))) == 2
|
||||
|
|
Loading…
Reference in a new issue