mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-13 00:28:14 +00:00
convert_dinov2: tweak command-line args
i.e. mimic the other conversion scripts
This commit is contained in:
parent
5ca1549c96
commit
832f012fe4
|
@ -1,4 +1,5 @@
|
||||||
import argparse
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -124,12 +125,35 @@ def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--weights_path", type=str, required=True)
|
parser.add_argument(
|
||||||
parser.add_argument("--output_path", type=str, required=True)
|
"--from",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
dest="source_path",
|
||||||
|
help=(
|
||||||
|
"Official checkpoint from https://github.com/facebookresearch/dinov2"
|
||||||
|
" e.g. /path/to/dinov2_vits14_pretrain.pth"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--to",
|
||||||
|
type=str,
|
||||||
|
dest="output_path",
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"Path to save the converted model. If not specified, the output path will be the source path with the"
|
||||||
|
" extension changed to .safetensors."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument("--half", action="store_true", dest="half")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
weights = torch.load(args.weights_path) # type: ignore
|
weights = torch.load(args.source_path) # type: ignore
|
||||||
convert_dinov2_facebook(weights)
|
convert_dinov2_facebook(weights)
|
||||||
|
if args.half:
|
||||||
|
weights = {key: value.half() for key, value in weights.items()}
|
||||||
|
if args.output_path is None:
|
||||||
|
args.output_path = f"{Path(args.source_path).stem}.safetensors"
|
||||||
save_to_safetensors(path=args.output_path, tensors=weights)
|
save_to_safetensors(path=args.output_path, tensors=weights)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue