mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-10 07:21:59 +00:00
237 lines
8.9 KiB
Python
237 lines
8.9 KiB
Python
|
import argparse
|
||
|
import types
|
||
|
from typing import Any, Callable, cast
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
from torch import Tensor
|
||
|
|
||
|
import refiners.fluxion.layers as fl
|
||
|
from refiners.fluxion.model_converter import ModelConverter
|
||
|
from refiners.fluxion.utils import manual_seed, save_to_safetensors
|
||
|
from refiners.foundationals.segment_anything.image_encoder import SAMViTH
|
||
|
from refiners.foundationals.segment_anything.prompt_encoder import PointEncoder, MaskEncoder
|
||
|
|
||
|
from segment_anything import build_sam_vit_h # type: ignore
|
||
|
from segment_anything.modeling.common import LayerNorm2d # type: ignore
|
||
|
|
||
|
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
|
||
|
|
||
|
|
||
|
class FacebookSAM(nn.Module):
    """Typing stub describing the sub-modules read from the original SAM model.

    Only used as the return type of the re-typed `build_sam_vit_h` builder so
    that attribute access on the three sub-modules type-checks.
    """

    # Sub-module handed to `convert_vit`.
    image_encoder: nn.Module
    # Sub-module handed to `convert_point_encoder` and `convert_mask_encoder`.
    prompt_encoder: nn.Module
    # Sub-module handed to `convert_mask_decoder`.
    mask_decoder: nn.Module
|
||
|
|
||
|
|
||
|
# Re-type the (untyped) imported builder as a zero-argument callable returning `FacebookSAM`;
# `cast` only affects static type checking, the function object itself is unchanged.
build_sam_vit_h = cast(Callable[[], FacebookSAM], build_sam_vit_h)
|
||
|
|
||
|
|
||
|
# Fail at import time if segment_anything's LayerNorm2d is not an nn.Module subclass.
assert issubclass(LayerNorm2d, nn.Module)
|
||
|
# Maps segment_anything's LayerNorm2d class to the Refiners LayerNorm2d class;
# passed to ModelConverter as `custom_layer_mapping`.
custom_layers = {LayerNorm2d: fl.LayerNorm2d}
|
||
|
|
||
|
|
||
|
class Args(argparse.Namespace):
    """Typed namespace for the command-line options parsed in `main`."""

    # Value of `--from`: path to the Segment Anything checkpoint.
    source_path: str
    # Value of `--to`: path of the safetensors file to write.
    output_path: str
    # Value of `--half`: whether to cast the converted tensors to half precision.
    half: bool
    # Value of `--verbose`.
    verbose: bool
|
||
|
|
||
|
|
||
|
def convert_mask_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]:
    """Extract the mask-encoder weights from a Segment Anything prompt encoder.

    Args:
        prompt_encoder: The source prompt encoder; its `no_mask_embed.weight`
            is read and cloned.

    Returns:
        A state dict holding a single `"no_mask_embedding"` entry.
    """
    no_mask_weight = prompt_encoder.no_mask_embed.weight.clone()  # type: ignore
    state_dict: dict[str, Tensor] = {
        "no_mask_embedding": nn.Parameter(data=no_mask_weight),  # type: ignore
    }

    # TODO: handle other weights
    # Non-strict load into a fresh Refiners MaskEncoder: only the entry above is
    # provided, so the remaining weights of the target are left untouched.
    MaskEncoder().load_state_dict(state_dict=state_dict, strict=False)

    return state_dict
|
||
|
|
||
|
|
||
|
def convert_point_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]:
    """Extract the point-encoder weights from a Segment Anything prompt encoder.

    Args:
        prompt_encoder: The source prompt encoder; its `point_embeddings`,
            `not_a_point_embed` and `pe_layer.positional_encoding_gaussian_matrix`
            attributes are read.

    Returns:
        A state dict with the concatenated point-type embedding weights and the
        transposed positional-encoding matrix, keyed by Refiners PointEncoder names.

    Raises:
        AssertionError: If the positional-encoding matrix is not a Tensor.
    """
    manual_seed(seed=0)

    # Weights of every entry in `point_embeddings`, followed by the `not_a_point_embed` weight.
    type_weights: list[Tensor] = [embedding.weight for embedding in prompt_encoder.point_embeddings]  # type: ignore
    type_weights.append(prompt_encoder.not_a_point_embed.weight)  # type: ignore

    gaussian_matrix = prompt_encoder.pe_layer.positional_encoding_gaussian_matrix  # type: ignore
    assert isinstance(gaussian_matrix, Tensor)

    point_type_weight = torch.cat(tensors=type_weights, dim=0)
    # The source matrix is stored transposed relative to the target Linear weight.
    coordinate_weight = gaussian_matrix.T.contiguous()

    state_dict: dict[str, Tensor] = {
        "Residual.Chain.PointTypeEmbedding.weight": nn.Parameter(data=point_type_weight),
        "CoordinateEncoder.Linear.weight": nn.Parameter(data=coordinate_weight),
    }

    # Loading into a fresh Refiners PointEncoder (default `load_state_dict` settings)
    # checks that the converted entries are accepted by the target model.
    PointEncoder().load_state_dict(state_dict=state_dict)

    return state_dict
|
||
|
|
||
|
|
||
|
def convert_vit(vit: nn.Module) -> dict[str, Tensor]:
    """Convert a Segment Anything image encoder state dict to the Refiners SAMViTH layout.

    Most weights are mapped automatically by `ModelConverter`. The positional
    embedding (`pos_embed`) and the per-block relative position embeddings
    (`rel_pos_w` / `rel_pos_h`) are removed from both state dicts before the
    automatic conversion and re-inserted by hand afterwards.

    Args:
        vit: The source image encoder.

    Returns:
        The converted state dict, keyed by SAMViTH parameter names.

    Raises:
        AssertionError: If no mapping could be established, or if the
            `compare_models` check (threshold 1e-3) fails after loading.
    """
    manual_seed(seed=0)
    refiners_sam_vit_h = SAMViTH()

    converter = ModelConverter(
        source_model=vit,
        target_model=refiners_sam_vit_h,
        custom_layer_mapping=custom_layers,  # type: ignore
    )
    converter.skip_init_check = True

    x = torch.randn(1, 3, 1024, 1024)
    mapping = converter.map_state_dicts(source_args=(x,))
    assert mapping

    positional_key = "PositionalEncoder.Chain.Parameter.parameter"
    mapping[positional_key] = "pos_embed"

    target_state_dict = refiners_sam_vit_h.state_dict()
    source_state_dict = vit.state_dict()

    # Take the positional embedding out of both sides; it is re-attached manually below.
    del target_state_dict[positional_key]
    pos_embed = source_state_dict.pop("pos_embed")

    # Same treatment for the relative position embeddings of the 32 blocks. Source
    # blocks are numbered from 0 while target transformer layers are numbered from 1;
    # `rel_pos_w` goes to the horizontal embedding and `rel_pos_h` to the vertical one.
    rel_items: dict[str, Tensor] = {}
    for block_index in range(32):
        layer_index = block_index + 1
        key_pairs = (
            (
                f"blocks.{block_index}.attn.rel_pos_w",
                f"Transformer.TransformerLayer_{layer_index}.Residual_1.Chain.FusedSelfAttention.RelativePositionAttention.horizontal_embedding",
            ),
            (
                f"blocks.{block_index}.attn.rel_pos_h",
                f"Transformer.TransformerLayer_{layer_index}.Residual_1.Chain.FusedSelfAttention.RelativePositionAttention.vertical_embedding",
            ),
        )
        for source_key, target_key in key_pairs:
            rel_items[target_key] = source_state_dict.pop(source_key)
            del target_state_dict[target_key]

    converted_source = converter._convert_state_dict(  # pyright: ignore[reportPrivateUsage]
        source_state_dict=source_state_dict,
        target_state_dict=target_state_dict,
        state_dict_mapping=mapping,
    )

    # Re-attach the manually handled parameters.
    converted_source[positional_key] = pos_embed  # type: ignore
    converted_source.update(rel_items)

    refiners_sam_vit_h.load_state_dict(state_dict=converted_source)
    assert converter.compare_models((x,), threshold=1e-3)

    return converted_source
|
||
|
|
||
|
|
||
|
def convert_mask_decoder(mask_decoder: nn.Module) -> dict[str, Tensor]:
    """Convert a Segment Anything mask decoder state dict to the Refiners MaskDecoder layout.

    Random embeddings are fed to both models so that `ModelConverter` can map
    their state dicts; the source receives them as keyword inputs, whereas the
    Refiners decoder receives them through its `set_*` methods.

    Args:
        mask_decoder: The source mask decoder; its `iou_token` and `mask_tokens`
            weights are concatenated into the target `IOUMaskEncoder.weight`.

    Returns:
        The converted state dict, keyed by MaskDecoder parameter names.

    Raises:
        AssertionError: If no mapping could be established, or if the
            `compare_models` check (threshold 1e-3) fails after loading.
    """
    manual_seed(seed=0)

    refiners_mask_decoder = MaskDecoder()

    image_embedding = torch.randn(1, 256, 64, 64)
    dense_positional_embedding = torch.randn(1, 256, 64, 64)
    point_embedding = torch.randn(1, 3, 256)
    mask_embedding = torch.randn(1, 256, 64, 64)

    def set_target_inputs() -> None:
        # Provide the Refiners decoder with the same tensors the source model gets as arguments.
        refiners_mask_decoder.set_image_embedding(image_embedding)
        refiners_mask_decoder.set_point_embedding(point_embedding)
        refiners_mask_decoder.set_mask_embedding(mask_embedding)
        refiners_mask_decoder.set_dense_positional_embedding(dense_positional_embedding)

    # Uses the module-level `custom_layers` mapping, like `convert_vit`.
    converter = ModelConverter(
        source_model=mask_decoder,
        target_model=refiners_mask_decoder,
        custom_layer_mapping=custom_layers,  # type: ignore
    )

    inputs = {
        "image_embeddings": image_embedding,
        "image_pe": dense_positional_embedding,
        "sparse_prompt_embeddings": point_embedding,
        "dense_prompt_embeddings": mask_embedding,
        "multimask_output": True,
    }

    set_target_inputs()

    mapping = converter.map_state_dicts(source_args=inputs, target_args={})
    assert mapping is not None
    mapping["IOUMaskEncoder"] = "iou_token"

    state_dict = converter._convert_state_dict(  # type: ignore
        source_state_dict=mask_decoder.state_dict(),
        target_state_dict=refiners_mask_decoder.state_dict(),
        state_dict_mapping=mapping,
    )
    # The target keeps the IoU token and the mask tokens in one weight: IoU token first.
    state_dict["IOUMaskEncoder.weight"] = torch.cat(
        tensors=[mask_decoder.iou_token.weight, mask_decoder.mask_tokens.weight], dim=0
    )  # type: ignore

    refiners_mask_decoder.load_state_dict(state_dict=state_dict)

    # Set the inputs again before the comparison run below.
    set_target_inputs()

    # Perform (1) upscaling then (2) mask prediction in this order (= like in the official implementation) to make
    # `compare_models` happy (MaskPrediction's Matmul runs those in the reverse order by default)
    matmul = refiners_mask_decoder.ensure_find(fl.Matmul)

    def forward_swapped_order(self: Any, *args: Any) -> Any:
        y = self[1](*args)  # (1)
        x = self[0](*args)  # (2)
        return torch.matmul(input=x, other=y)

    matmul.forward = types.MethodType(forward_swapped_order, matmul)

    assert converter.compare_models(source_args=inputs, target_args={}, threshold=1e-3)

    return state_dict
|
||
|
|
||
|
|
||
|
def main() -> None:
    """Convert a Segment Anything ViT-H checkpoint to a Refiners safetensors file.

    Parses the command line, loads the source checkpoint into a freshly built
    SAM model, converts its image encoder, mask decoder and prompt encoder, then
    writes all converted tensors (each key prefixed with its component name) to
    the output path, optionally cast to half precision.
    """
    parser = argparse.ArgumentParser(description="Converts a Segment Anything ViT model to a Refiners SAMViTH model")
    parser.add_argument(
        "--from",
        type=str,
        dest="source_path",
        default="sam_vit_h_4b8939.pth",
        help="Path to the Segment Anything model weights",
    )
    parser.add_argument(
        "--to",
        type=str,
        dest="output_path",
        default="segment-anything-h.safetensors",
        help="Output path for converted model (as safetensors).",
    )
    parser.add_argument("--half", action="store_true", default=False, help="Convert to half precision. Default: False")
    # NOTE: `--verbose` is parsed but not consulted anywhere in this function.
    parser.add_argument(
        "--verbose",
        action="store_true",
        default=False,
        help="Prints additional information during conversion. Default: False",
    )
    args = parser.parse_args(namespace=Args())

    sam_h = build_sam_vit_h()  # type: ignore
    # The conversion below runs on CPU tensors, so map the checkpoint to CPU: this keeps
    # checkpoints saved from a GPU loadable on machines without CUDA.
    # NOTE(security): `torch.load` unpickles the file; only use checkpoints from a trusted source.
    sam_h.load_state_dict(state_dict=torch.load(f=args.source_path, map_location="cpu"))  # type: ignore

    vit_state_dict = convert_vit(vit=sam_h.image_encoder)
    mask_decoder_state_dict = convert_mask_decoder(mask_decoder=sam_h.mask_decoder)
    point_encoder_state_dict = convert_point_encoder(prompt_encoder=sam_h.prompt_encoder)
    mask_encoder_state_dict = convert_mask_encoder(prompt_encoder=sam_h.prompt_encoder)

    # Namespace every converted key with the name of the component it belongs to.
    converted_parts = (
        ("image_encoder", vit_state_dict),
        ("mask_decoder", mask_decoder_state_dict),
        ("point_encoder", point_encoder_state_dict),
        ("mask_encoder", mask_encoder_state_dict),
    )
    output_state_dict = {
        f"{prefix}.{key}": value for prefix, part in converted_parts for key, value in part.items()
    }
    if args.half:
        output_state_dict = {key: value.half() for key, value in output_state_dict.items()}

    save_to_safetensors(path=args.output_path, tensors=output_state_dict)
|
||
|
|
||
|
|
||
|
# Run the conversion only when this file is executed as a script.
if __name__ == "__main__":
    main()
|