split test in two

This commit is contained in:
Pierre Chapuis 2023-08-23 14:57:04 +02:00
parent 337d2aea58
commit e05c410a86

View file

@ -2,29 +2,34 @@ import torch
import refiners.fluxion.layers as fl import refiners.fluxion.layers as fl
def test_chain_remove_replace(): def test_chain_remove() -> None:
chain = fl.Chain( chain = fl.Chain(
fl.Linear(1, 1), fl.Linear(1, 1),
fl.Linear(1, 1), fl.Linear(1, 1),
fl.Chain( fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)),
)
assert len(chain) == 3
assert "Linear_1" in chain._modules
assert "Linear" not in chain._modules
chain.remove(chain.Linear_2)
assert len(chain) == 2
assert "Linear" in chain._modules
assert "Linear_1" not in chain._modules
def test_chain_replace() -> None:
chain = fl.Chain(
fl.Linear(1, 1), fl.Linear(1, 1),
fl.Linear(1, 1), fl.Linear(1, 1),
fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)), fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)),
),
fl.Conv2d(1, 1, 1),
) )
assert len(chain) == 4
assert len(chain.Chain) == 3
chain.remove(chain[-1]) assert isinstance(chain.Chain[1], fl.Linear)
chain.Chain.replace(chain.Chain[1], fl.Conv2d(1, 1, 1))
assert len(chain) == 3 assert len(chain) == 3
assert len(chain.Chain) == 3 assert isinstance(chain.Chain[1], fl.Conv2d)
assert isinstance(chain.Chain.Chain[1], fl.Linear)
chain.Chain.Chain.replace(chain.Chain.Chain[1], fl.Conv2d(1, 1, 1))
assert len(chain) == 3
assert len(chain.Chain) == 3
assert isinstance(chain.Chain.Chain[1], fl.Conv2d)
def test_chain_structural_copy() -> None: def test_chain_structural_copy() -> None: