From 457c3f5cbd9fec91399aad815549e8e281c6ffb0 Mon Sep 17 00:00:00 2001 From: Colle Date: Thu, 11 Jan 2024 22:37:35 +0100 Subject: [PATCH] display weighted module dtype and device (#173) Co-authored-by: Benjamin Trom --- src/refiners/fluxion/layers/module.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/refiners/fluxion/layers/module.py b/src/refiners/fluxion/layers/module.py index 1fc663e..bb612fd 100644 --- a/src/refiners/fluxion/layers/module.py +++ b/src/refiners/fluxion/layers/module.py @@ -164,6 +164,9 @@ class WeightedModule(Module): def dtype(self) -> DType: return self.weight.dtype + def __str__(self) -> str: + return f"{super().__str__().removesuffix(')')}, device={self.device}, dtype={str(self.dtype).removeprefix('torch.')})" + class TreeNode(TypedDict): value: str