convert_dinov2: ignore pyright errors

And save converted weights into safetensors instead of pickle
This commit is contained in:
Cédric Deltheil 2023-12-14 17:39:41 +01:00 committed by Cédric Deltheil
parent 9337d65e0e
commit e978b3665d

View file

@ -2,6 +2,8 @@ import argparse
import torch import torch
from refiners.fluxion.utils import save_to_safetensors
def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None: def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:
"""Convert a DINOv2 weights from facebook to refiners.""" """Convert a DINOv2 weights from facebook to refiners."""
@ -126,9 +128,9 @@ def main() -> None:
parser.add_argument("--output_path", type=str, required=True) parser.add_argument("--output_path", type=str, required=True)
args = parser.parse_args() args = parser.parse_args()
weights = torch.load(args.weights_path) weights = torch.load(args.weights_path) # type: ignore
convert_dinov2_facebook(weights) convert_dinov2_facebook(weights)
torch.save(weights, args.output_path) save_to_safetensors(path=args.output_path, tensors=weights)
if __name__ == "__main__": if __name__ == "__main__":