diff --git a/scripts/conversion/convert_dinov2.py b/scripts/conversion/convert_dinov2.py index 5e8af35..0d71d47 100644 --- a/scripts/conversion/convert_dinov2.py +++ b/scripts/conversion/convert_dinov2.py @@ -1,4 +1,5 @@ import argparse +from pathlib import Path import torch @@ -124,12 +125,35 @@ def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None: def main() -> None: parser = argparse.ArgumentParser() - parser.add_argument("--weights_path", type=str, required=True) - parser.add_argument("--output_path", type=str, required=True) + parser.add_argument( + "--from", + type=str, + required=True, + dest="source_path", + help=( + "Official checkpoint from https://github.com/facebookresearch/dinov2" + " e.g. /path/to/dinov2_vits14_pretrain.pth" + ), + ) + parser.add_argument( + "--to", + type=str, + dest="output_path", + default=None, + help=( + "Path to save the converted model. If not specified, the output path will be the source path with the" + " extension changed to .safetensors." + ), + ) + parser.add_argument("--half", action="store_true", dest="half") args = parser.parse_args() - weights = torch.load(args.weights_path) # type: ignore + weights = torch.load(args.source_path) # type: ignore convert_dinov2_facebook(weights) + if args.half: + weights = {key: value.half() for key, value in weights.items()} + if args.output_path is None: + args.output_path = f"{Path(args.source_path).stem}.safetensors" save_to_safetensors(path=args.output_path, tensors=weights)