refiners/scripts/conversion/convert_hq_segment_anything.py

82 lines
3 KiB
Python
Raw Normal View History

2024-03-21 13:59:36 +00:00
import argparse
from torch import Tensor
from refiners.fluxion.utils import load_tensors, save_to_safetensors
def main() -> None:
parser = argparse.ArgumentParser(description="Convert HQ SAM model to Refiners state_dict format")
parser.add_argument(
"--from",
type=str,
dest="source_path",
required=True,
default="sam_hq_vit_h.pth",
help="Path to the source model checkpoint.",
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
required=True,
default="refiners_sam_hq_vit_h.safetensors",
help="Path to save the converted model in Refiners format.",
)
args = parser.parse_args()
source_state_dict = load_tensors(args.source_path)
state_dict: dict[str, Tensor] = {}
for suffix in ["weight", "bias"]:
state_dict[f"HQFeatures.CompressViTFeat.ConvTranspose2d_1.{suffix}"] = source_state_dict[
f"mask_decoder.compress_vit_feat.0.{suffix}"
]
state_dict[f"HQFeatures.EmbeddingEncoder.ConvTranspose2d_1.{suffix}"] = source_state_dict[
f"mask_decoder.embedding_encoder.0.{suffix}"
]
state_dict[f"EmbeddingMaskfeature.Conv2d_1.{suffix}"] = source_state_dict[
f"mask_decoder.embedding_maskfeature.0.{suffix}"
]
state_dict[f"HQFeatures.CompressViTFeat.LayerNorm2d.{suffix}"] = source_state_dict[
f"mask_decoder.compress_vit_feat.1.{suffix}"
]
state_dict[f"HQFeatures.EmbeddingEncoder.LayerNorm2d.{suffix}"] = source_state_dict[
f"mask_decoder.embedding_encoder.1.{suffix}"
]
state_dict[f"EmbeddingMaskfeature.LayerNorm2d.{suffix}"] = source_state_dict[
f"mask_decoder.embedding_maskfeature.1.{suffix}"
]
state_dict[f"HQFeatures.CompressViTFeat.ConvTranspose2d_2.{suffix}"] = source_state_dict[
f"mask_decoder.compress_vit_feat.3.{suffix}"
]
state_dict[f"HQFeatures.EmbeddingEncoder.ConvTranspose2d_2.{suffix}"] = source_state_dict[
f"mask_decoder.embedding_encoder.3.{suffix}"
]
state_dict[f"EmbeddingMaskfeature.Conv2d_2.{suffix}"] = source_state_dict[
f"mask_decoder.embedding_maskfeature.3.{suffix}"
]
state_dict = {f"Chain.HQSAMMaskPrediction.Chain.DenseEmbeddingUpscalingHQ.{k}": v for k, v in state_dict.items()}
# HQ Token
state_dict["MaskDecoderTokensExtender.hq_token.weight"] = source_state_dict["mask_decoder.hf_token.weight"]
# HQ MLP
for i in range(3):
state_dict[f"Chain.HQSAMMaskPrediction.HQTokenMLP.MultiLinear.Linear_{i+1}.weight"] = source_state_dict[
f"mask_decoder.hf_mlp.layers.{i}.weight"
]
state_dict[f"Chain.HQSAMMaskPrediction.HQTokenMLP.MultiLinear.Linear_{i+1}.bias"] = source_state_dict[
f"mask_decoder.hf_mlp.layers.{i}.bias"
]
save_to_safetensors(path=args.output_path, tensors=state_dict)
if __name__ == "__main__":
main()