split test in two

This commit is contained in:
Pierre Chapuis 2023-08-23 14:57:04 +02:00
parent 337d2aea58
commit e05c410a86

View file

@ -2,29 +2,34 @@ import torch
import refiners.fluxion.layers as fl
def test_chain_remove_replace():
def test_chain_remove() -> None:
chain = fl.Chain(
fl.Linear(1, 1),
fl.Linear(1, 1),
fl.Chain(
fl.Linear(1, 1),
fl.Linear(1, 1),
fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)),
),
fl.Conv2d(1, 1, 1),
fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)),
)
assert len(chain) == 4
assert len(chain.Chain) == 3
chain.remove(chain[-1])
assert len(chain) == 3
assert len(chain.Chain) == 3
assert "Linear_1" in chain._modules
assert "Linear" not in chain._modules
assert isinstance(chain.Chain.Chain[1], fl.Linear)
chain.Chain.Chain.replace(chain.Chain.Chain[1], fl.Conv2d(1, 1, 1))
chain.remove(chain.Linear_2)
assert len(chain) == 2
assert "Linear" in chain._modules
assert "Linear_1" not in chain._modules
def test_chain_replace() -> None:
chain = fl.Chain(
fl.Linear(1, 1),
fl.Linear(1, 1),
fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)),
)
assert isinstance(chain.Chain[1], fl.Linear)
chain.Chain.replace(chain.Chain[1], fl.Conv2d(1, 1, 1))
assert len(chain) == 3
assert len(chain.Chain) == 3
assert isinstance(chain.Chain.Chain[1], fl.Conv2d)
assert isinstance(chain.Chain[1], fl.Conv2d)
def test_chain_structural_copy() -> None: