diff --git a/src/refiners/fluxion/layers/module.py b/src/refiners/fluxion/layers/module.py index 1fc663e..bb612fd 100644 --- a/src/refiners/fluxion/layers/module.py +++ b/src/refiners/fluxion/layers/module.py @@ -164,6 +164,9 @@ class WeightedModule(Module): def dtype(self) -> DType: return self.weight.dtype + def __str__(self) -> str: + return f"{super().__str__().removesuffix(')')}, device={self.device}, dtype={str(self.dtype).removeprefix('torch.')})" + class TreeNode(TypedDict): value: str