mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-14 09:08:14 +00:00
convert_dinov2: ignore pyright errors
And save converted weights into safetensors instead of pickle
This commit is contained in:
parent
9337d65e0e
commit
e978b3665d
|
@ -2,6 +2,8 @@ import argparse
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import save_to_safetensors
|
||||||
|
|
||||||
|
|
||||||
def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:
|
def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:
|
||||||
"""Convert a DINOv2 weights from facebook to refiners."""
|
"""Convert a DINOv2 weights from facebook to refiners."""
|
||||||
|
@ -126,9 +128,9 @@ def main() -> None:
|
||||||
parser.add_argument("--output_path", type=str, required=True)
|
parser.add_argument("--output_path", type=str, required=True)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
weights = torch.load(args.weights_path)
|
weights = torch.load(args.weights_path) # type: ignore
|
||||||
convert_dinov2_facebook(weights)
|
convert_dinov2_facebook(weights)
|
||||||
torch.save(weights, args.output_path)
|
save_to_safetensors(path=args.output_path, tensors=weights)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in a new issue