Distribute: improve sanity check error message

E.g.:

    AssertionError: Number of positional arguments (1) must match number of sub-modules (2).
This commit is contained in:
Cédric Deltheil 2023-09-28 11:40:56 +02:00 committed by Cédric Deltheil
parent 620e58a593
commit 88e454f1cb

View file

@ -411,7 +411,8 @@ class Distribute(Chain):
_tag = "DISTR"
def forward(self, *args: Any) -> tuple[Tensor, ...]:
assert len(args) == len(self._modules), "Number of positional arguments must match number of sub-modules."
n, m = len(args), len(self._modules)
assert n == m, f"Number of positional arguments ({n}) must match number of sub-modules ({m})."
return tuple([self.call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())])
def _show_only_tag(self) -> bool: