From bba478abf2d5dd0069dae76a125650d191c9d0d8 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Mon, 29 Jan 2024 17:32:41 +0100 Subject: [PATCH] create test_module --- tests/fluxion/layers/test_chain.py | 14 -------------- tests/fluxion/test_module.py | 15 +++++++++++++++ 2 files changed, 15 insertions(+), 14 deletions(-) create mode 100644 tests/fluxion/test_module.py diff --git a/tests/fluxion/layers/test_chain.py b/tests/fluxion/layers/test_chain.py index 0694a52..402564c 100644 --- a/tests/fluxion/layers/test_chain.py +++ b/tests/fluxion/layers/test_chain.py @@ -243,17 +243,3 @@ def test_debug_print() -> None: ) assert chain._show_error_in_tree("Chain.Linear_2") == EXPECTED_TREE # type: ignore[reportPrivateUsage] - - -def test_module_get_path() -> None: - chain = fl.Chain( - fl.Sum( - fl.Linear(1, 1), - fl.Linear(1, 1), - ), - fl.Sum(), - ) - - assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1) == "Chain.Sum_1.Linear_2" - assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1, top=chain.Sum_1) == "Sum.Linear_2" - assert chain.Sum_1.get_path() == "Chain.Sum_1" diff --git a/tests/fluxion/test_module.py b/tests/fluxion/test_module.py new file mode 100644 index 0000000..3e7ace7 --- /dev/null +++ b/tests/fluxion/test_module.py @@ -0,0 +1,15 @@ +import refiners.fluxion.layers as fl + + +def test_module_get_path() -> None: + chain = fl.Chain( + fl.Sum( + fl.Linear(1, 1), + fl.Linear(1, 1), + ), + fl.Sum(), + ) + + assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1) == "Chain.Sum_1.Linear_2" + assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1, top=chain.Sum_1) == "Sum.Linear_2" + assert chain.Sum_1.get_path() == "Chain.Sum_1"