mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
add a get_path
helper to modules
This commit is contained in:
parent
0ee2d5e075
commit
ce0339b4cc
|
@ -94,6 +94,21 @@ class Module(TorchModule):
|
||||||
"""
|
"""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def get_path(self, parent: "Chain | None" = None, top: "Module | None" = None) -> str:
|
||||||
|
"""Helper for debugging purpose only.
|
||||||
|
|
||||||
|
Returns the path of the module in the chain as a string.
|
||||||
|
|
||||||
|
If `top` is set then the path will be relative to `top`,
|
||||||
|
otherwise it will be relative to the root of the chain.
|
||||||
|
"""
|
||||||
|
if (parent is None) or (self == top):
|
||||||
|
return self.__class__.__name__
|
||||||
|
for k, m in parent._modules.items(): # type: ignore
|
||||||
|
if m is self:
|
||||||
|
return parent.get_path(parent=parent.parent, top=top) + "." + k
|
||||||
|
raise ValueError(f"{self} not found in {parent}")
|
||||||
|
|
||||||
|
|
||||||
class ContextModule(Module):
|
class ContextModule(Module):
|
||||||
# we store parent into a one element list to avoid pytorch thinking it's a submodule
|
# we store parent into a one element list to avoid pytorch thinking it's a submodule
|
||||||
|
@ -154,6 +169,9 @@ class ContextModule(Module):
|
||||||
|
|
||||||
return clone
|
return clone
|
||||||
|
|
||||||
|
def get_path(self, parent: "Chain | None" = None, top: "Module | None" = None) -> str:
|
||||||
|
return super().get_path(parent=parent or self.parent, top=top)
|
||||||
|
|
||||||
|
|
||||||
class WeightedModule(Module):
|
class WeightedModule(Module):
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -243,3 +243,17 @@ def test_debug_print() -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
assert chain._show_error_in_tree("Chain.Linear_2") == EXPECTED_TREE # type: ignore[reportPrivateUsage]
|
assert chain._show_error_in_tree("Chain.Linear_2") == EXPECTED_TREE # type: ignore[reportPrivateUsage]
|
||||||
|
|
||||||
|
|
||||||
|
def test_module_get_path() -> None:
|
||||||
|
chain = fl.Chain(
|
||||||
|
fl.Sum(
|
||||||
|
fl.Linear(1, 1),
|
||||||
|
fl.Linear(1, 1),
|
||||||
|
),
|
||||||
|
fl.Sum(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1) == "Chain.Sum_1.Linear_2"
|
||||||
|
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1, top=chain.Sum_1) == "Sum.Linear_2"
|
||||||
|
assert chain.Sum_1.get_path() == "Chain.Sum_1"
|
||||||
|
|
Loading…
Reference in a new issue