mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
add a get_path
helper to modules
This commit is contained in:
parent
0ee2d5e075
commit
ce0339b4cc
|
@ -94,6 +94,21 @@ class Module(TorchModule):
|
|||
"""
|
||||
return False
|
||||
|
||||
def get_path(self, parent: "Chain | None" = None, top: "Module | None" = None) -> str:
|
||||
"""Helper for debugging purpose only.
|
||||
|
||||
Returns the path of the module in the chain as a string.
|
||||
|
||||
If `top` is set then the path will be relative to `top`,
|
||||
otherwise it will be relative to the root of the chain.
|
||||
"""
|
||||
if (parent is None) or (self == top):
|
||||
return self.__class__.__name__
|
||||
for k, m in parent._modules.items(): # type: ignore
|
||||
if m is self:
|
||||
return parent.get_path(parent=parent.parent, top=top) + "." + k
|
||||
raise ValueError(f"{self} not found in {parent}")
|
||||
|
||||
|
||||
class ContextModule(Module):
|
||||
# we store parent into a one element list to avoid pytorch thinking it's a submodule
|
||||
|
@ -154,6 +169,9 @@ class ContextModule(Module):
|
|||
|
||||
return clone
|
||||
|
||||
def get_path(self, parent: "Chain | None" = None, top: "Module | None" = None) -> str:
|
||||
return super().get_path(parent=parent or self.parent, top=top)
|
||||
|
||||
|
||||
class WeightedModule(Module):
|
||||
@property
|
||||
|
|
|
@ -243,3 +243,17 @@ def test_debug_print() -> None:
|
|||
)
|
||||
|
||||
assert chain._show_error_in_tree("Chain.Linear_2") == EXPECTED_TREE # type: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
def test_module_get_path() -> None:
|
||||
chain = fl.Chain(
|
||||
fl.Sum(
|
||||
fl.Linear(1, 1),
|
||||
fl.Linear(1, 1),
|
||||
),
|
||||
fl.Sum(),
|
||||
)
|
||||
|
||||
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1) == "Chain.Sum_1.Linear_2"
|
||||
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1, top=chain.Sum_1) == "Sum.Linear_2"
|
||||
assert chain.Sum_1.get_path() == "Chain.Sum_1"
|
||||
|
|
Loading…
Reference in a new issue