display weighted module dtype and device (#173)

Co-authored-by: Benjamin Trom <benjamintrom@gmail.com>
This commit is contained in:
Colle 2024-01-11 22:37:35 +01:00 committed by GitHub
parent 14ce2f50f9
commit 457c3f5cbd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -164,6 +164,9 @@ class WeightedModule(Module):
def dtype(self) -> DType: def dtype(self) -> DType:
return self.weight.dtype return self.weight.dtype
def __str__(self) -> str:
return f"{super().__str__().removesuffix(')')}, device={self.device}, dtype={str(self.dtype).removeprefix('torch.')})"
class TreeNode(TypedDict): class TreeNode(TypedDict):
value: str value: str