create test_module

This commit is contained in:
Pierre Chapuis 2024-01-29 17:32:41 +01:00 committed by Cédric Deltheil
parent be961af4d9
commit bba478abf2
2 changed files with 15 additions and 14 deletions

View file

@ -243,17 +243,3 @@ def test_debug_print() -> None:
)
assert chain._show_error_in_tree("Chain.Linear_2") == EXPECTED_TREE # type: ignore[reportPrivateUsage]
def test_module_get_path() -> None:
chain = fl.Chain(
fl.Sum(
fl.Linear(1, 1),
fl.Linear(1, 1),
),
fl.Sum(),
)
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1) == "Chain.Sum_1.Linear_2"
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1, top=chain.Sum_1) == "Sum.Linear_2"
assert chain.Sum_1.get_path() == "Chain.Sum_1"

View file

@ -0,0 +1,15 @@
import refiners.fluxion.layers as fl
def test_module_get_path() -> None:
chain = fl.Chain(
fl.Sum(
fl.Linear(1, 1),
fl.Linear(1, 1),
),
fl.Sum(),
)
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1) == "Chain.Sum_1.Linear_2"
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1, top=chain.Sum_1) == "Sum.Linear_2"
assert chain.Sum_1.get_path() == "Chain.Sum_1"