mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
parent
cf270885a4
commit
a542337a83
|
@ -243,11 +243,15 @@ class Chain(ContextModule):
|
||||||
raise ChainError(message) from None
|
raise ChainError(message) from None
|
||||||
|
|
||||||
def forward(self, *args: Any) -> Any:
|
def forward(self, *args: Any) -> Any:
|
||||||
result: tuple[Any] | Any = None
|
result: Any = None
|
||||||
intermediate_args: tuple[Any, ...] = args
|
intermediate_args: tuple[Any, ...] = args
|
||||||
for name, layer in self._modules.items():
|
for name, layer in self._modules.items():
|
||||||
result = self._call_layer(layer, name, *intermediate_args)
|
result = self._call_layer(layer, name, *intermediate_args)
|
||||||
intermediate_args = (result,) if not isinstance(result, tuple) else result
|
if isinstance(result, tuple):
|
||||||
|
result = cast(tuple[Any], result)
|
||||||
|
intermediate_args = result
|
||||||
|
else:
|
||||||
|
intermediate_args = (result,)
|
||||||
|
|
||||||
self._reset_context()
|
self._reset_context()
|
||||||
return result
|
return result
|
||||||
|
|
Loading…
Reference in a new issue