improve load_from_safetensors typing
Some checks failed
CI / lint_and_typecheck (push) Has been cancelled
Deploy docs to GitHub Pages / Deploy docs (push) Has been cancelled
Spell checker / Spell check (push) Has been cancelled

This commit is contained in:
limiteinductive 2024-08-01 14:04:24 +00:00 committed by Benjamin Trom
parent 1de567590b
commit b4ee65b9b1

View file

@ -42,7 +42,7 @@ class Module(TorchModule):
def __setattr__(self, name: str, value: Any) -> None:
return super().__setattr__(name, value)
def load_from_safetensors(self, tensors_path: str | Path, strict: bool = True) -> "Module":
def load_from_safetensors(self: T, tensors_path: str | Path, strict: bool = True) -> T:
"""Load the module's state from a SafeTensors file.
Args: