From 88e454f1cbb623ce32937821ae784f31b5f1c08d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Deltheil?= Date: Thu, 28 Sep 2023 11:40:56 +0200 Subject: [PATCH] Distribute: improve sanity check error message E.g.: AssertionError: Number of positional arguments (1) must match number of sub-modules (2). --- src/refiners/fluxion/layers/chain.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/refiners/fluxion/layers/chain.py b/src/refiners/fluxion/layers/chain.py index 157b599..c857ca8 100644 --- a/src/refiners/fluxion/layers/chain.py +++ b/src/refiners/fluxion/layers/chain.py @@ -411,7 +411,8 @@ class Distribute(Chain): _tag = "DISTR" def forward(self, *args: Any) -> tuple[Tensor, ...]: - assert len(args) == len(self._modules), "Number of positional arguments must match number of sub-modules." + n, m = len(args), len(self._modules) + assert n == m, f"Number of positional arguments ({n}) must match number of sub-modules ({m})." return tuple([self.call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())]) def _show_only_tag(self) -> bool: