diff --git a/src/refiners/fluxion/layers/module.py b/src/refiners/fluxion/layers/module.py index c9689e1..46b31cd 100644 --- a/src/refiners/fluxion/layers/module.py +++ b/src/refiners/fluxion/layers/module.py @@ -42,7 +42,7 @@ class Module(TorchModule): def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) - def load_from_safetensors(self, tensors_path: str | Path, strict: bool = True) -> "Module": + def load_from_safetensors(self: T, tensors_path: str | Path, strict: bool = True) -> T: """Load the module's state from a SafeTensors file. Args: