diff --git a/tests/fluxion/layers/test_chain.py b/tests/fluxion/layers/test_chain.py index c443c1d..504cb68 100644 --- a/tests/fluxion/layers/test_chain.py +++ b/tests/fluxion/layers/test_chain.py @@ -2,29 +2,34 @@ import torch import refiners.fluxion.layers as fl -def test_chain_remove_replace(): +def test_chain_remove() -> None: chain = fl.Chain( fl.Linear(1, 1), fl.Linear(1, 1), - fl.Chain( - fl.Linear(1, 1), - fl.Linear(1, 1), - fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)), - ), - fl.Conv2d(1, 1, 1), + fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)), ) - assert len(chain) == 4 - assert len(chain.Chain) == 3 - chain.remove(chain[-1]) assert len(chain) == 3 - assert len(chain.Chain) == 3 + assert "Linear_1" in chain._modules + assert "Linear" not in chain._modules - assert isinstance(chain.Chain.Chain[1], fl.Linear) - chain.Chain.Chain.replace(chain.Chain.Chain[1], fl.Conv2d(1, 1, 1)) + chain.remove(chain.Linear_2) + assert len(chain) == 2 + assert "Linear" in chain._modules + assert "Linear_1" not in chain._modules + + +def test_chain_replace() -> None: + chain = fl.Chain( + fl.Linear(1, 1), + fl.Linear(1, 1), + fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)), + ) + + assert isinstance(chain.Chain[1], fl.Linear) + chain.Chain.replace(chain.Chain[1], fl.Conv2d(1, 1, 1)) assert len(chain) == 3 - assert len(chain.Chain) == 3 - assert isinstance(chain.Chain.Chain[1], fl.Conv2d) + assert isinstance(chain.Chain[1], fl.Conv2d) def test_chain_structural_copy() -> None: