diff --git a/src/refiners/fluxion/layers/chain.py b/src/refiners/fluxion/layers/chain.py index 6971c4a..cfabe15 100644 --- a/src/refiners/fluxion/layers/chain.py +++ b/src/refiners/fluxion/layers/chain.py @@ -142,7 +142,7 @@ class Chain(ContextModule): self._reset_context() for module in self: - if isinstance(module, ContextModule) and module._can_refresh_parent and module.parent != self: + if isinstance(module, ContextModule) and module.parent != self: module._set_parent(self) @property diff --git a/src/refiners/fluxion/layers/module.py b/src/refiners/fluxion/layers/module.py index edfa423..edf864b 100644 --- a/src/refiners/fluxion/layers/module.py +++ b/src/refiners/fluxion/layers/module.py @@ -62,6 +62,8 @@ class ContextModule(Module): return self._parent[0] def _set_parent(self, parent: "Chain | None") -> None: + if not self._can_refresh_parent: + return if parent is None: self._parent = [] return