mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
create test_module
This commit is contained in:
parent
be961af4d9
commit
bba478abf2
|
@ -243,17 +243,3 @@ def test_debug_print() -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
assert chain._show_error_in_tree("Chain.Linear_2") == EXPECTED_TREE # type: ignore[reportPrivateUsage]
|
assert chain._show_error_in_tree("Chain.Linear_2") == EXPECTED_TREE # type: ignore[reportPrivateUsage]
|
||||||
|
|
||||||
|
|
||||||
def test_module_get_path() -> None:
|
|
||||||
chain = fl.Chain(
|
|
||||||
fl.Sum(
|
|
||||||
fl.Linear(1, 1),
|
|
||||||
fl.Linear(1, 1),
|
|
||||||
),
|
|
||||||
fl.Sum(),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1) == "Chain.Sum_1.Linear_2"
|
|
||||||
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1, top=chain.Sum_1) == "Sum.Linear_2"
|
|
||||||
assert chain.Sum_1.get_path() == "Chain.Sum_1"
|
|
||||||
|
|
15
tests/fluxion/test_module.py
Normal file
15
tests/fluxion/test_module.py
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
|
|
||||||
|
def test_module_get_path() -> None:
|
||||||
|
chain = fl.Chain(
|
||||||
|
fl.Sum(
|
||||||
|
fl.Linear(1, 1),
|
||||||
|
fl.Linear(1, 1),
|
||||||
|
),
|
||||||
|
fl.Sum(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1) == "Chain.Sum_1.Linear_2"
|
||||||
|
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1, top=chain.Sum_1) == "Sum.Linear_2"
|
||||||
|
assert chain.Sum_1.get_path() == "Chain.Sum_1"
|
Loading…
Reference in a new issue