diff --git a/src/refiners/fluxion/layers/chain.py b/src/refiners/fluxion/layers/chain.py index f0331a7..0cb5f12 100644 --- a/src/refiners/fluxion/layers/chain.py +++ b/src/refiners/fluxion/layers/chain.py @@ -243,11 +243,15 @@ class Chain(ContextModule): raise ChainError(message) from None def forward(self, *args: Any) -> Any: - result: tuple[Any] | Any = None + result: Any = None intermediate_args: tuple[Any, ...] = args for name, layer in self._modules.items(): result = self._call_layer(layer, name, *intermediate_args) - intermediate_args = (result,) if not isinstance(result, tuple) else result + if isinstance(result, tuple): + result = cast(tuple[Any], result) + intermediate_args = result + else: + intermediate_args = (result,) self._reset_context() return result