From b4ee65b9b1e1598ae90eb8b723b48778679a89ea Mon Sep 17 00:00:00 2001 From: limiteinductive Date: Thu, 1 Aug 2024 14:04:24 +0000 Subject: [PATCH] improve load_from_safetensors typing --- src/refiners/fluxion/layers/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/refiners/fluxion/layers/module.py b/src/refiners/fluxion/layers/module.py index c9689e1..46b31cd 100644 --- a/src/refiners/fluxion/layers/module.py +++ b/src/refiners/fluxion/layers/module.py @@ -42,7 +42,7 @@ class Module(TorchModule): def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) - def load_from_safetensors(self, tensors_path: str | Path, strict: bool = True) -> "Module": + def load_from_safetensors(self: T, tensors_path: str | Path, strict: bool = True) -> T: """Load the module's state from a SafeTensors file. Args: