mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
convert_dinov2: tweak command-line args
i.e. mimic the other conversion scripts
This commit is contained in:
parent
5ca1549c96
commit
832f012fe4
|
@ -1,4 +1,5 @@
|
|||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -124,12 +125,35 @@ def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:
|
|||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--weights_path", type=str, required=True)
|
||||
parser.add_argument("--output_path", type=str, required=True)
|
||||
parser.add_argument(
|
||||
"--from",
|
||||
type=str,
|
||||
required=True,
|
||||
dest="source_path",
|
||||
help=(
|
||||
"Official checkpoint from https://github.com/facebookresearch/dinov2"
|
||||
" e.g. /path/to/dinov2_vits14_pretrain.pth"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--to",
|
||||
type=str,
|
||||
dest="output_path",
|
||||
default=None,
|
||||
help=(
|
||||
"Path to save the converted model. If not specified, the output path will be the source path with the"
|
||||
" extension changed to .safetensors."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--half", action="store_true", dest="half")
|
||||
args = parser.parse_args()
|
||||
|
||||
weights = torch.load(args.weights_path) # type: ignore
|
||||
weights = torch.load(args.source_path) # type: ignore
|
||||
convert_dinov2_facebook(weights)
|
||||
if args.half:
|
||||
weights = {key: value.half() for key, value in weights.items()}
|
||||
if args.output_path is None:
|
||||
args.output_path = f"{Path(args.source_path).stem}.safetensors"
|
||||
save_to_safetensors(path=args.output_path, tensors=weights)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue