import argparse
import types
from typing import Any, Callable, cast

import torch
import torch.nn as nn
from segment_anything import build_sam_vit_h  # type: ignore
from segment_anything.modeling.common import LayerNorm2d  # type: ignore
from torch import Tensor

import refiners.fluxion.layers as fl
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import load_tensors, manual_seed, save_to_safetensors
from refiners.foundationals.segment_anything.image_encoder import PositionalEncoder, SAMViTH
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder


class FacebookSAM(nn.Module):
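    """Typing shim for the original Segment Anything model (`build_sam_vit_h` is cast to return one below)."""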
    image_encoder: nn.Module
    prompt_encoder: nn.Module
    mask_decoder: nn.Module


build_sam_vit_h = cast(Callable[[], FacebookSAM], build_sam_vit_h)


assert issubclass(LayerNorm2d, nn.Module)
custom_layers = {LayerNorm2d: fl.LayerNorm2d}


class Args(argparse.Namespace):
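    """Typed namespace for the command-line arguments parsed in `main`."""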
    source_path: str
    output_path: str
    half: bool
    verbose: bool


def convert_mask_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]:
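    """Convert the prompt encoder's `mask_downscaling` (plus `no_mask_embed`) into Refiners `MaskEncoder` weights."""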
    manual_seed(seed=0)
    refiners_mask_encoder = MaskEncoder()

    converter = ModelConverter(
        source_model=prompt_encoder.mask_downscaling,
        target_model=refiners_mask_encoder,
        custom_layer_mapping=custom_layers,  # type: ignore
    )

    x = torch.randn(1, 256, 256)
    mapping = converter.map_state_dicts(source_args=(x,))
    assert mapping

    source_state_dict = prompt_encoder.mask_downscaling.state_dict()
    target_state_dict = refiners_mask_encoder.state_dict()

    # Mapping handled manually (see below) because nn.Parameter is a special case
    del target_state_dict["no_mask_embedding"]

    converted_source = converter._convert_state_dict(  # pyright: ignore[reportPrivateUsage]
        source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
    )

    state_dict: dict[str, Tensor] = {
        "no_mask_embedding": nn.Parameter(data=prompt_encoder.no_mask_embed.weight.clone()),  # type: ignore
    }
    state_dict.update(converted_source)

    refiners_mask_encoder.load_state_dict(state_dict=state_dict)

    return state_dict


def convert_point_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]:
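    """Convert SAM's point embeddings and positional encoding matrix into Refiners `PointEncoder` weights."""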
    manual_seed(seed=0)
    point_embeddings: list[Tensor] = [pe.weight for pe in prompt_encoder.point_embeddings] + [
        prompt_encoder.not_a_point_embed.weight
    ]  # type: ignore
    pe = prompt_encoder.pe_layer.positional_encoding_gaussian_matrix  # type: ignore
    assert isinstance(pe, Tensor)
    state_dict: dict[str, Tensor] = {
        "Residual.PointTypeEmbedding.weight": nn.Parameter(data=torch.cat(tensors=point_embeddings, dim=0)),
        "CoordinateEncoder.Linear.weight": nn.Parameter(data=pe.T.contiguous()),
    }

    refiners_prompt_encoder = PointEncoder()
    refiners_prompt_encoder.load_state_dict(state_dict=state_dict)

    return state_dict


def convert_vit(vit: nn.Module) -> dict[str, Tensor]:
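    """Convert SAM's ViT-H image encoder into Refiners `SAMViTH` weights."""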
    manual_seed(seed=0)
    refiners_sam_vit_h = SAMViTH()

    converter = ModelConverter(
        source_model=vit,
        target_model=refiners_sam_vit_h,
        custom_layer_mapping=custom_layers,  # type: ignore
    )
    converter.skip_init_check = True

    x = torch.randn(1, 3, 1024, 1024)
    mapping = converter.map_state_dicts(source_args=(x,))
    assert mapping
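
    # As with `no_mask_embedding` above, `pos_embed` is an nn.Parameter (a special
    # case for the converter): map it by name here and reshape it manually below.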
    mapping["PositionalEncoder.Parameter.weight"] = "pos_embed"

    target_state_dict = refiners_sam_vit_h.state_dict()
    del target_state_dict["PositionalEncoder.Parameter.weight"]

    source_state_dict = vit.state_dict()
    pos_embed = source_state_dict["pos_embed"]
    del source_state_dict["pos_embed"]
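
    # Relative position embeddings are transferred by name: Refiners numbers its
    # TransformerLayers starting at 1, while SAM numbers its blocks starting at 0.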
    target_rel_keys = [
        (
            f"Transformer.TransformerLayer_{i}.Residual_1.FusedSelfAttention.RelativePositionAttention.horizontal_embedding",
            f"Transformer.TransformerLayer_{i}.Residual_1.FusedSelfAttention.RelativePositionAttention.vertical_embedding",
        )
        for i in range(1, 33)
    ]
    source_rel_keys = [(f"blocks.{i}.attn.rel_pos_w", f"blocks.{i}.attn.rel_pos_h") for i in range(32)]

    rel_items: dict[str, Tensor] = {}

    for (key_w, key_h), (target_key_w, target_key_h) in zip(source_rel_keys, target_rel_keys):
        rel_items[target_key_w] = source_state_dict[key_w]
        rel_items[target_key_h] = source_state_dict[key_h]
        del source_state_dict[key_w]
        del source_state_dict[key_h]
        del target_state_dict[target_key_w]
        del target_state_dict[target_key_h]

    converted_source = converter._convert_state_dict(  # pyright: ignore[reportPrivateUsage]
        source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
    )

    positional_encoder = refiners_sam_vit_h.layer("PositionalEncoder", PositionalEncoder)
    embed = pos_embed.reshape_as(positional_encoder.layer("Parameter", fl.Parameter).weight)
    converted_source["PositionalEncoder.Parameter.weight"] = embed  # type: ignore
    converted_source.update(rel_items)

    refiners_sam_vit_h.load_state_dict(state_dict=converted_source)

    assert converter.compare_models((x,), threshold=1e-2)

    return converted_source


def convert_mask_decoder(mask_decoder: nn.Module) -> dict[str, Tensor]:
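    """Convert SAM's mask decoder into Refiners `MaskDecoder` weights and check that outputs match."""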
    manual_seed(seed=0)
    refiners_mask_decoder = MaskDecoder()

    image_embedding = torch.randn(1, 256, 64, 64)
    dense_positional_embedding = torch.randn(1, 256, 64, 64)
    point_embedding = torch.randn(1, 3, 256)
    mask_embedding = torch.randn(1, 256, 64, 64)
    converter = ModelConverter(
        source_model=mask_decoder,
        target_model=refiners_mask_decoder,
        custom_layer_mapping=custom_layers,  # type: ignore
    )

    inputs = {
        "image_embeddings": image_embedding,
        "image_pe": dense_positional_embedding,
        "sparse_prompt_embeddings": point_embedding,
        "dense_prompt_embeddings": mask_embedding,
        "multimask_output": True,
    }
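
    # The Refiners decoder receives its embeddings via setters rather than as forward
    # arguments, so set them before mapping the state dicts.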
    refiners_mask_decoder.set_image_embedding(image_embedding)
    refiners_mask_decoder.set_point_embedding(point_embedding)
    refiners_mask_decoder.set_mask_embedding(mask_embedding)
    refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding)

    mapping = converter.map_state_dicts(source_args=inputs, target_args={})
    assert mapping is not None
    mapping["MaskDecoderTokens.Parameter"] = "iou_token"
    state_dict = converter._convert_state_dict(  # type: ignore
        source_state_dict=mask_decoder.state_dict(),
        target_state_dict=refiners_mask_decoder.state_dict(),
        state_dict_mapping=mapping,
    )
    state_dict["MaskDecoderTokens.Parameter.weight"] = torch.cat(
        tensors=[mask_decoder.iou_token.weight, mask_decoder.mask_tokens.weight], dim=0
    )  # type: ignore
    refiners_mask_decoder.load_state_dict(state_dict=state_dict)

    refiners_mask_decoder.set_image_embedding(image_embedding)
    refiners_mask_decoder.set_point_embedding(point_embedding)
    refiners_mask_decoder.set_mask_embedding(mask_embedding)
    refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding)

    # Perform (1) upscaling then (2) mask prediction in this order (= like in the official implementation)
    # to make `compare_models` happy (MaskPrediction's Matmul runs those in the reverse order by default)
    matmul = refiners_mask_decoder.ensure_find(fl.Matmul)
    def forward_swapped_order(self: Any, *args: Any) -> Any:
        y = self[1](*args)  # (1)
        x = self[0](*args)  # (2)
        return torch.matmul(input=x, other=y)

    matmul.forward = types.MethodType(forward_swapped_order, matmul)

    assert converter.compare_models(source_args=inputs, target_args={}, threshold=1e-3)

    return state_dict


def main() -> None:
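    """Parse arguments, load the original SAM ViT-H, convert each submodule, and save one safetensors file."""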
    parser = argparse.ArgumentParser(
        description="Converts a Segment Anything ViT-H model (image encoder, prompt encoders, mask decoder) to Refiners weights"
    )
    parser.add_argument(
        "--from",
        type=str,
        dest="source_path",
        default="sam_vit_h_4b8939.pth",
        # required=True,
        help="Path to the Segment Anything model weights.",
    )
    parser.add_argument(
        "--to",
        type=str,
        dest="output_path",
        default="segment-anything-h.safetensors",
        help="Output path for converted model (as safetensors).",
    )
    parser.add_argument("--half", action="store_true", default=False, help="Convert to half precision. Default: False")
    parser.add_argument(
        "--verbose",
        action="store_true",
        default=False,
        help="Prints additional information during conversion. Default: False",
    )
    args = parser.parse_args(namespace=Args())
    sam_h = build_sam_vit_h()  # type: ignore
    sam_h.load_state_dict(state_dict=load_tensors(args.source_path))

    vit_state_dict = convert_vit(vit=sam_h.image_encoder)
    mask_decoder_state_dict = convert_mask_decoder(mask_decoder=sam_h.mask_decoder)
    point_encoder_state_dict = convert_point_encoder(prompt_encoder=sam_h.prompt_encoder)
    mask_encoder_state_dict = convert_mask_encoder(prompt_encoder=sam_h.prompt_encoder)
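
    # Merge everything into one flat state dict, namespacing each submodule's keys
    # by its Refiners class name.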
    output_state_dict = {
        **{f"SAMViTH.{key}": value for key, value in vit_state_dict.items()},
        **{f"MaskDecoder.{key}": value for key, value in mask_decoder_state_dict.items()},
        **{f"PointEncoder.{key}": value for key, value in point_encoder_state_dict.items()},
        **{f"MaskEncoder.{key}": value for key, value in mask_encoder_state_dict.items()},
    }
    if args.half:
        output_state_dict = {key: value.half() for key, value in output_state_dict.items()}

    save_to_safetensors(path=args.output_path, tensors=output_state_dict)


if __name__ == "__main__":
    main()