diff --git a/src/refiners/fluxion/layers/module.py b/src/refiners/fluxion/layers/module.py index bb612fd..78ecb0d 100644 --- a/src/refiners/fluxion/layers/module.py +++ b/src/refiners/fluxion/layers/module.py @@ -94,6 +94,21 @@ class Module(TorchModule): """ return False + def get_path(self, parent: "Chain | None" = None, top: "Module | None" = None) -> str: + """Helper for debugging purpose only. + + Returns the path of the module in the chain as a string. + + If `top` is set then the path will be relative to `top`, + otherwise it will be relative to the root of the chain. + """ + if (parent is None) or (self == top): + return self.__class__.__name__ + for k, m in parent._modules.items(): # type: ignore + if m is self: + return parent.get_path(parent=parent.parent, top=top) + "." + k + raise ValueError(f"{self} not found in {parent}") + class ContextModule(Module): # we store parent into a one element list to avoid pytorch thinking it's a submodule @@ -154,6 +169,9 @@ class ContextModule(Module): return clone + def get_path(self, parent: "Chain | None" = None, top: "Module | None" = None) -> str: + return super().get_path(parent=parent or self.parent, top=top) + class WeightedModule(Module): @property diff --git a/tests/fluxion/layers/test_chain.py b/tests/fluxion/layers/test_chain.py index 402564c..0694a52 100644 --- a/tests/fluxion/layers/test_chain.py +++ b/tests/fluxion/layers/test_chain.py @@ -243,3 +243,17 @@ def test_debug_print() -> None: ) assert chain._show_error_in_tree("Chain.Linear_2") == EXPECTED_TREE # type: ignore[reportPrivateUsage] + + +def test_module_get_path() -> None: + chain = fl.Chain( + fl.Sum( + fl.Linear(1, 1), + fl.Linear(1, 1), + ), + fl.Sum(), + ) + + assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1) == "Chain.Sum_1.Linear_2" + assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1, top=chain.Sum_1) == "Sum.Linear_2" + assert chain.Sum_1.get_path() == "Chain.Sum_1"