convert_dinov2: tweak command-line args

i.e. mimic the other conversion scripts
This commit is contained in:
Cédric Deltheil 2023-12-16 10:10:27 +01:00 committed by Cédric Deltheil
parent 5ca1549c96
commit 832f012fe4

View file

@ -1,4 +1,5 @@
import argparse
from pathlib import Path
import torch
@ -124,12 +125,35 @@ def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--weights_path", type=str, required=True)
parser.add_argument("--output_path", type=str, required=True)
parser.add_argument(
"--from",
type=str,
required=True,
dest="source_path",
help=(
"Official checkpoint from https://github.com/facebookresearch/dinov2"
" e.g. /path/to/dinov2_vits14_pretrain.pth"
),
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
default=None,
help=(
"Path to save the converted model. If not specified, the output path will be the source path with the"
" extension changed to .safetensors."
),
)
parser.add_argument("--half", action="store_true", dest="half")
args = parser.parse_args()
weights = torch.load(args.weights_path) # type: ignore
weights = torch.load(args.source_path) # type: ignore
convert_dinov2_facebook(weights)
if args.half:
weights = {key: value.half() for key, value in weights.items()}
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}.safetensors"
save_to_safetensors(path=args.output_path, tensors=weights)