From 832f012fe4b9d53a799c885ccdc40a3e95b35e42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Deltheil?= Date: Sat, 16 Dec 2023 10:10:27 +0100 Subject: [PATCH] convert_dinov2: tweak command-line args i.e. mimic the other conversion scripts --- scripts/conversion/convert_dinov2.py | 30 +++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/scripts/conversion/convert_dinov2.py b/scripts/conversion/convert_dinov2.py index 5e8af35..0d71d47 100644 --- a/scripts/conversion/convert_dinov2.py +++ b/scripts/conversion/convert_dinov2.py @@ -1,4 +1,5 @@ import argparse +from pathlib import Path import torch @@ -124,12 +125,35 @@ def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None: def main() -> None: parser = argparse.ArgumentParser() - parser.add_argument("--weights_path", type=str, required=True) - parser.add_argument("--output_path", type=str, required=True) + parser.add_argument( + "--from", + type=str, + required=True, + dest="source_path", + help=( + "Official checkpoint from https://github.com/facebookresearch/dinov2" + " e.g. /path/to/dinov2_vits14_pretrain.pth" + ), + ) + parser.add_argument( + "--to", + type=str, + dest="output_path", + default=None, + help=( + "Path to save the converted model. If not specified, the output path will be the source path with the" + " extension changed to .safetensors." + ), + ) + parser.add_argument("--half", action="store_true", dest="half") args = parser.parse_args() - weights = torch.load(args.weights_path) # type: ignore + weights = torch.load(args.source_path) # type: ignore convert_dinov2_facebook(weights) + if args.half: + weights = {key: value.half() for key, value in weights.items()} + if args.output_path is None: + args.output_path = f"{Path(args.source_path).stem}.safetensors" save_to_safetensors(path=args.output_path, tensors=weights)