add a get_path helper to modules

This commit is contained in:
Pierre Chapuis 2024-01-26 17:00:36 +01:00
parent 0ee2d5e075
commit ce0339b4cc
2 changed files with 32 additions and 0 deletions

View file

@ -94,6 +94,21 @@ class Module(TorchModule):
"""
return False
def get_path(self, parent: "Chain | None" = None, top: "Module | None" = None) -> str:
"""Helper for debugging purpose only.
Returns the path of the module in the chain as a string.
If `top` is set then the path will be relative to `top`,
otherwise it will be relative to the root of the chain.
"""
if (parent is None) or (self == top):
return self.__class__.__name__
for k, m in parent._modules.items(): # type: ignore
if m is self:
return parent.get_path(parent=parent.parent, top=top) + "." + k
raise ValueError(f"{self} not found in {parent}")
class ContextModule(Module):
# we store parent into a one element list to avoid pytorch thinking it's a submodule
@ -154,6 +169,9 @@ class ContextModule(Module):
return clone
def get_path(self, parent: "Chain | None" = None, top: "Module | None" = None) -> str:
return super().get_path(parent=parent or self.parent, top=top)
class WeightedModule(Module):
@property

View file

@ -243,3 +243,17 @@ def test_debug_print() -> None:
)
assert chain._show_error_in_tree("Chain.Linear_2") == EXPECTED_TREE # type: ignore[reportPrivateUsage]
def test_module_get_path() -> None:
chain = fl.Chain(
fl.Sum(
fl.Linear(1, 1),
fl.Linear(1, 1),
),
fl.Sum(),
)
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1) == "Chain.Sum_1.Linear_2"
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1, top=chain.Sum_1) == "Sum.Linear_2"
assert chain.Sum_1.get_path() == "Chain.Sum_1"