mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 06:38:45 +00:00
hq-sam : weights/load_weights
This commit is contained in:
parent
2b48988c07
commit
e8fcab9d17
|
@ -293,6 +293,8 @@ class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]):
|
|||
See [[arXiv:2306.01567] Segment Anything in High Quality](https://arxiv.org/abs/2306.01567) for details.
|
||||
"""
|
||||
|
||||
_adapter_modules: dict[str, fl.Module] = {}
|
||||
|
||||
def init_context(self) -> Contexts:
|
||||
return {"hq_sam": {"early_vit_embedding": None}}
|
||||
|
||||
|
@ -318,27 +320,40 @@ class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]):
|
|||
mask_prediction, self.vit_embedding_dim, self.target_num_mask_tokens, target.device, target.dtype
|
||||
)
|
||||
]
|
||||
self.register_adapter_module("Chain.HQSAMMaskPrediction", self.mask_prediction_adapter.hq_sam_mask_prediction)
|
||||
|
||||
self._image_encoder_adapter = [SAMViTAdapter(target.image_encoder)]
|
||||
self._predictions_post_proc = [PredictionsPostProc(hq_mask_only)]
|
||||
|
||||
mask_decoder_tokens = target.mask_decoder.ensure_find(MaskDecoderTokens)
|
||||
self._mask_decoder_tokens_extender = [MaskDecoderTokensExtender(mask_decoder_tokens)]
|
||||
self.register_adapter_module("MaskDecoderTokensExtender.hq_token", self.mask_decoder_tokens_extender.hq_token)
|
||||
|
||||
if weights is not None:
|
||||
hq_token_prefix = "MaskDecoderTokensExtender.hq_token."
|
||||
hq_token_state_dict: dict[str, torch.Tensor] = {
|
||||
k.removeprefix(hq_token_prefix): v for k, v in weights.items() if k.startswith(hq_token_prefix)
|
||||
}
|
||||
self.mask_decoder_tokens_extender.hq_token.load_state_dict(hq_token_state_dict)
|
||||
|
||||
mask_pred_prefix = "Chain.HQSAMMaskPrediction."
|
||||
mask_pred_state_dict: dict[str, torch.Tensor] = {
|
||||
k.removeprefix(mask_pred_prefix): v for k, v in weights.items() if k.startswith(mask_pred_prefix)
|
||||
}
|
||||
self.mask_prediction_adapter.hq_sam_mask_prediction.load_state_dict(mask_pred_state_dict)
|
||||
self.load_weights(weights)
|
||||
|
||||
self.to(device=target.device, dtype=target.dtype)
|
||||
|
||||
def register_adapter_module(self, module_key: str, adapter_module: fl.Module):
|
||||
"""Register the modules used in weights and load_weights."""
|
||||
self._adapter_modules[module_key] = adapter_module
|
||||
|
||||
@property
|
||||
def weights(self) -> dict[str, torch.Tensor]:
|
||||
out: dict[str, torch.Tensor] = {}
|
||||
for module_key in self._adapter_modules:
|
||||
adapter_module = self._adapter_modules[module_key]
|
||||
state_dict = adapter_module.state_dict()
|
||||
out.update({f"{module_key}.{key}": state_dict[key] for key in state_dict})
|
||||
return out
|
||||
|
||||
def load_weights(self, weights: dict[str, torch.Tensor]) -> None:
|
||||
for module_key in self._adapter_modules:
|
||||
state_dict: dict[str, torch.Tensor] = {
|
||||
k.removeprefix(f"{module_key}."): v for k, v in weights.items() if k.startswith(f"{module_key}.")
|
||||
}
|
||||
self._adapter_modules[module_key].load_state_dict(state_dict)
|
||||
|
||||
@property
|
||||
def mask_decoder_tokens_extender(self) -> MaskDecoderTokensExtender:
|
||||
return self._mask_decoder_tokens_extender[0]
|
||||
|
|
|
@ -346,3 +346,19 @@ def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) -
|
|||
assert mask_prediction.shape == (batch_size, 1, 256, 256)
|
||||
assert iou_prediction.shape == (batch_size, 1)
|
||||
assert torch.equal(mask_prediction[0], mask_prediction[1])
|
||||
|
||||
|
||||
def test_hq_sam_load_save_weights(sam_h: SegmentAnythingH, hq_adapter_weights: Path, test_device: torch.device) -> None:
|
||||
weights = load_from_safetensors(hq_adapter_weights, device=test_device)
|
||||
|
||||
hq_sam_adapter = HQSAMAdapter(sam_h)
|
||||
out_weights_init = hq_sam_adapter.weights
|
||||
|
||||
assert set(out_weights_init.keys()) == set(weights.keys())
|
||||
|
||||
hq_sam_adapter = HQSAMAdapter(sam_h, weights=weights)
|
||||
out_weights = hq_sam_adapter.weights
|
||||
|
||||
assert set(out_weights.keys()) == set(weights.keys())
|
||||
for key in out_weights.keys():
|
||||
assert torch.equal(out_weights[key], weights[key])
|
||||
|
|
Loading…
Reference in a new issue