# Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-10 07:21:59 +00:00); original file: 82 lines, 3 KiB, Python.
|
import argparse
|
||
|
|
||
|
from torch import Tensor
|
||
|
|
||
|
from refiners.fluxion.utils import load_tensors, save_to_safetensors
|
||
|
|
||
|
|
||
|
def main() -> None:
    """Convert an HQ SAM checkpoint to the Refiners state_dict format.

    Command-line options:
        --from: path to the source model checkpoint
            (default: ``sam_hq_vit_h.pth``).
        --to: path where the converted safetensors file is written
            (default: ``refiners_sam_hq_vit_h.safetensors``).

    The source tensors are copied unchanged; only their keys are renamed.
    A missing source key presumably surfaces as ``KeyError`` (assumes
    ``load_tensors`` returns a plain mapping -- TODO confirm).
    """
    parser = argparse.ArgumentParser(description="Convert HQ SAM model to Refiners state_dict format")
    # NOTE: these options are intentionally NOT `required`: argparse never
    # falls back to `default` for a required option, which would make the
    # defaults below dead and force both flags on every invocation.
    parser.add_argument(
        "--from",
        type=str,
        dest="source_path",
        default="sam_hq_vit_h.pth",
        help="Path to the source model checkpoint.",
    )
    parser.add_argument(
        "--to",
        type=str,
        dest="output_path",
        default="refiners_sam_hq_vit_h.safetensors",
        help="Path to save the converted model in Refiners format.",
    )
    args = parser.parse_args()

    source_state_dict = load_tensors(args.source_path)

    # (Refiners key stem, source key stem) pairs; each stem has a `weight`
    # and a `bias` entry. Order is kept identical to the historical script.
    upscaling_key_map = [
        ("HQFeatures.CompressViTFeat.ConvTranspose2d_1", "mask_decoder.compress_vit_feat.0"),
        ("HQFeatures.EmbeddingEncoder.ConvTranspose2d_1", "mask_decoder.embedding_encoder.0"),
        ("EmbeddingMaskfeature.Conv2d_1", "mask_decoder.embedding_maskfeature.0"),
        ("HQFeatures.CompressViTFeat.LayerNorm2d", "mask_decoder.compress_vit_feat.1"),
        ("HQFeatures.EmbeddingEncoder.LayerNorm2d", "mask_decoder.embedding_encoder.1"),
        ("EmbeddingMaskfeature.LayerNorm2d", "mask_decoder.embedding_maskfeature.1"),
        ("HQFeatures.CompressViTFeat.ConvTranspose2d_2", "mask_decoder.compress_vit_feat.3"),
        ("HQFeatures.EmbeddingEncoder.ConvTranspose2d_2", "mask_decoder.embedding_encoder.3"),
        ("EmbeddingMaskfeature.Conv2d_2", "mask_decoder.embedding_maskfeature.3"),
    ]
    upscaling_prefix = "Chain.HQSAMMaskPrediction.Chain.DenseEmbeddingUpscalingHQ"

    state_dict: dict[str, Tensor] = {
        f"{upscaling_prefix}.{target}.{suffix}": source_state_dict[f"{source}.{suffix}"]
        for suffix in ("weight", "bias")
        for target, source in upscaling_key_map
    }

    # HQ Token (not nested under the upscaling prefix)
    state_dict["MaskDecoderTokensExtender.hq_token.weight"] = source_state_dict["mask_decoder.hf_token.weight"]

    # HQ MLP: source layers are 0-indexed, Refiners `Linear_*` names are 1-indexed
    mlp_prefix = "Chain.HQSAMMaskPrediction.HQTokenMLP.MultiLinear"
    for i in range(3):
        for suffix in ("weight", "bias"):
            state_dict[f"{mlp_prefix}.Linear_{i+1}.{suffix}"] = source_state_dict[
                f"mask_decoder.hf_mlp.layers.{i}.{suffix}"
            ]

    save_to_safetensors(path=args.output_path, tensors=state_dict)
|
||
|
|
||
|
|
||
|
# Run the conversion only when this file is executed as a script.
if __name__ == "__main__":
    main()
|