hq-sam : weights/load_weights

This commit is contained in:
Pierre Colle 2024-03-28 14:21:12 +00:00
parent 2b48988c07
commit e8fcab9d17
2 changed files with 42 additions and 11 deletions

View file

@ -293,6 +293,8 @@ class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]):
See [[arXiv:2306.01567] Segment Anything in High Quality](https://arxiv.org/abs/2306.01567) for details. See [[arXiv:2306.01567] Segment Anything in High Quality](https://arxiv.org/abs/2306.01567) for details.
""" """
_adapter_modules: dict[str, fl.Module] = {}
def init_context(self) -> Contexts: def init_context(self) -> Contexts:
return {"hq_sam": {"early_vit_embedding": None}} return {"hq_sam": {"early_vit_embedding": None}}
@ -318,27 +320,40 @@ class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]):
mask_prediction, self.vit_embedding_dim, self.target_num_mask_tokens, target.device, target.dtype mask_prediction, self.vit_embedding_dim, self.target_num_mask_tokens, target.device, target.dtype
) )
] ]
self.register_adapter_module("Chain.HQSAMMaskPrediction", self.mask_prediction_adapter.hq_sam_mask_prediction)
self._image_encoder_adapter = [SAMViTAdapter(target.image_encoder)] self._image_encoder_adapter = [SAMViTAdapter(target.image_encoder)]
self._predictions_post_proc = [PredictionsPostProc(hq_mask_only)] self._predictions_post_proc = [PredictionsPostProc(hq_mask_only)]
mask_decoder_tokens = target.mask_decoder.ensure_find(MaskDecoderTokens) mask_decoder_tokens = target.mask_decoder.ensure_find(MaskDecoderTokens)
self._mask_decoder_tokens_extender = [MaskDecoderTokensExtender(mask_decoder_tokens)] self._mask_decoder_tokens_extender = [MaskDecoderTokensExtender(mask_decoder_tokens)]
self.register_adapter_module("MaskDecoderTokensExtender.hq_token", self.mask_decoder_tokens_extender.hq_token)
if weights is not None: if weights is not None:
hq_token_prefix = "MaskDecoderTokensExtender.hq_token." self.load_weights(weights)
hq_token_state_dict: dict[str, torch.Tensor] = {
k.removeprefix(hq_token_prefix): v for k, v in weights.items() if k.startswith(hq_token_prefix)
}
self.mask_decoder_tokens_extender.hq_token.load_state_dict(hq_token_state_dict)
mask_pred_prefix = "Chain.HQSAMMaskPrediction."
mask_pred_state_dict: dict[str, torch.Tensor] = {
k.removeprefix(mask_pred_prefix): v for k, v in weights.items() if k.startswith(mask_pred_prefix)
}
self.mask_prediction_adapter.hq_sam_mask_prediction.load_state_dict(mask_pred_state_dict)
self.to(device=target.device, dtype=target.dtype) self.to(device=target.device, dtype=target.dtype)
def register_adapter_module(self, module_key: str, adapter_module: fl.Module):
"""Register the modules used in weights and load_weights."""
self._adapter_modules[module_key] = adapter_module
@property
def weights(self) -> dict[str, torch.Tensor]:
out: dict[str, torch.Tensor] = {}
for module_key in self._adapter_modules:
adapter_module = self._adapter_modules[module_key]
state_dict = adapter_module.state_dict()
out.update({f"{module_key}.{key}": state_dict[key] for key in state_dict})
return out
def load_weights(self, weights: dict[str, torch.Tensor]) -> None:
for module_key in self._adapter_modules:
state_dict: dict[str, torch.Tensor] = {
k.removeprefix(f"{module_key}."): v for k, v in weights.items() if k.startswith(f"{module_key}.")
}
self._adapter_modules[module_key].load_state_dict(state_dict)
@property @property
def mask_decoder_tokens_extender(self) -> MaskDecoderTokensExtender: def mask_decoder_tokens_extender(self) -> MaskDecoderTokensExtender:
return self._mask_decoder_tokens_extender[0] return self._mask_decoder_tokens_extender[0]

View file

@ -346,3 +346,19 @@ def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) -
assert mask_prediction.shape == (batch_size, 1, 256, 256) assert mask_prediction.shape == (batch_size, 1, 256, 256)
assert iou_prediction.shape == (batch_size, 1) assert iou_prediction.shape == (batch_size, 1)
assert torch.equal(mask_prediction[0], mask_prediction[1]) assert torch.equal(mask_prediction[0], mask_prediction[1])
def test_hq_sam_load_save_weights(sam_h: SegmentAnythingH, hq_adapter_weights: Path, test_device: torch.device) -> None:
weights = load_from_safetensors(hq_adapter_weights, device=test_device)
hq_sam_adapter = HQSAMAdapter(sam_h)
out_weights_init = hq_sam_adapter.weights
assert set(out_weights_init.keys()) == set(weights.keys())
hq_sam_adapter = HQSAMAdapter(sam_h, weights=weights)
out_weights = hq_sam_adapter.weights
assert set(out_weights.keys()) == set(weights.keys())
for key in out_weights.keys():
assert torch.equal(out_weights[key], weights[key])