diff --git a/tests/fluxion/layers/test_chain.py b/tests/fluxion/layers/test_chain.py index 0694a52..402564c 100644 --- a/tests/fluxion/layers/test_chain.py +++ b/tests/fluxion/layers/test_chain.py @@ -243,17 +243,3 @@ def test_debug_print() -> None: ) assert chain._show_error_in_tree("Chain.Linear_2") == EXPECTED_TREE # type: ignore[reportPrivateUsage] - - -def test_module_get_path() -> None: - chain = fl.Chain( - fl.Sum( - fl.Linear(1, 1), - fl.Linear(1, 1), - ), - fl.Sum(), - ) - - assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1) == "Chain.Sum_1.Linear_2" - assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1, top=chain.Sum_1) == "Sum.Linear_2" - assert chain.Sum_1.get_path() == "Chain.Sum_1" diff --git a/tests/fluxion/test_module.py b/tests/fluxion/test_module.py new file mode 100644 index 0000000..3e7ace7 --- /dev/null +++ b/tests/fluxion/test_module.py @@ -0,0 +1,15 @@ +import refiners.fluxion.layers as fl + + +def test_module_get_path() -> None: + chain = fl.Chain( + fl.Sum( + fl.Linear(1, 1), + fl.Linear(1, 1), + ), + fl.Sum(), + ) + + assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1) == "Chain.Sum_1.Linear_2" + assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1, top=chain.Sum_1) == "Sum.Linear_2" + assert chain.Sum_1.get_path() == "Chain.Sum_1"