diff --git a/src/refiners/fluxion/layers/chain.py b/src/refiners/fluxion/layers/chain.py index 157b599..c857ca8 100644 --- a/src/refiners/fluxion/layers/chain.py +++ b/src/refiners/fluxion/layers/chain.py @@ -411,7 +411,8 @@ class Distribute(Chain): _tag = "DISTR" def forward(self, *args: Any) -> tuple[Tensor, ...]: - assert len(args) == len(self._modules), "Number of positional arguments must match number of sub-modules." + n, m = len(args), len(self._modules) + assert n == m, f"Number of positional arguments ({n}) must match number of sub-modules ({m})." return tuple([self.call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())]) def _show_only_tag(self) -> bool: