mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
create test_module
This commit is contained in:
parent
be961af4d9
commit
bba478abf2
|
@ -243,17 +243,3 @@ def test_debug_print() -> None:
|
|||
)
|
||||
|
||||
assert chain._show_error_in_tree("Chain.Linear_2") == EXPECTED_TREE # type: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
def test_module_get_path() -> None:
|
||||
chain = fl.Chain(
|
||||
fl.Sum(
|
||||
fl.Linear(1, 1),
|
||||
fl.Linear(1, 1),
|
||||
),
|
||||
fl.Sum(),
|
||||
)
|
||||
|
||||
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1) == "Chain.Sum_1.Linear_2"
|
||||
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1, top=chain.Sum_1) == "Sum.Linear_2"
|
||||
assert chain.Sum_1.get_path() == "Chain.Sum_1"
|
||||
|
|
15
tests/fluxion/test_module.py
Normal file
15
tests/fluxion/test_module.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
import refiners.fluxion.layers as fl
|
||||
|
||||
|
||||
def test_module_get_path() -> None:
|
||||
chain = fl.Chain(
|
||||
fl.Sum(
|
||||
fl.Linear(1, 1),
|
||||
fl.Linear(1, 1),
|
||||
),
|
||||
fl.Sum(),
|
||||
)
|
||||
|
||||
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1) == "Chain.Sum_1.Linear_2"
|
||||
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1, top=chain.Sum_1) == "Sum.Linear_2"
|
||||
assert chain.Sum_1.get_path() == "Chain.Sum_1"
|
Loading…
Reference in a new issue