diff --git a/src/refiners/foundationals/segment_anything/hq_sam.py b/src/refiners/foundationals/segment_anything/hq_sam.py
index 9d54aec..f87e1e0 100644
--- a/src/refiners/foundationals/segment_anything/hq_sam.py
+++ b/src/refiners/foundationals/segment_anything/hq_sam.py
@@ -293,6 +293,9 @@ class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]):
     See [[arXiv:2306.01567] Segment Anything in High Quality](https://arxiv.org/abs/2306.01567) for details.
     """
 
+    # Class-level default: never mutated. `register_adapter_module` gives each instance its own registry.
+    _adapter_modules: dict[str, fl.Module] = {}
+
     def init_context(self) -> Contexts:
         return {"hq_sam": {"early_vit_embedding": None}}
 
@@ -318,27 +321,59 @@ class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]):
                 mask_prediction, self.vit_embedding_dim, self.target_num_mask_tokens, target.device, target.dtype
             )
         ]
+        self.register_adapter_module("Chain.HQSAMMaskPrediction", self.mask_prediction_adapter.hq_sam_mask_prediction)
+
         self._image_encoder_adapter = [SAMViTAdapter(target.image_encoder)]
         self._predictions_post_proc = [PredictionsPostProc(hq_mask_only)]
 
         mask_decoder_tokens = target.mask_decoder.ensure_find(MaskDecoderTokens)
         self._mask_decoder_tokens_extender = [MaskDecoderTokensExtender(mask_decoder_tokens)]
+        self.register_adapter_module("MaskDecoderTokensExtender.hq_token", self.mask_decoder_tokens_extender.hq_token)
 
         if weights is not None:
-            hq_token_prefix = "MaskDecoderTokensExtender.hq_token."
-            hq_token_state_dict: dict[str, torch.Tensor] = {
-                k.removeprefix(hq_token_prefix): v for k, v in weights.items() if k.startswith(hq_token_prefix)
-            }
-            self.mask_decoder_tokens_extender.hq_token.load_state_dict(hq_token_state_dict)
-
-            mask_pred_prefix = "Chain.HQSAMMaskPrediction."
-            mask_pred_state_dict: dict[str, torch.Tensor] = {
-                k.removeprefix(mask_pred_prefix): v for k, v in weights.items() if k.startswith(mask_pred_prefix)
-            }
-            self.mask_prediction_adapter.hq_sam_mask_prediction.load_state_dict(mask_pred_state_dict)
+            self.load_weights(weights)
 
         self.to(device=target.device, dtype=target.dtype)
 
+    def register_adapter_module(self, module_key: str, adapter_module: fl.Module) -> None:
+        """Register a module whose parameters are exposed by `weights` and restored by `load_weights`.
+
+        Args:
+            module_key: Prefix under which the module's state dict keys are exposed.
+            adapter_module: The module to register.
+        """
+        # Mutating the class-level `_adapter_modules` dict would share it between all adapters:
+        # give each instance its own registry the first time it registers a module.
+        if "_adapter_modules" not in self.__dict__:
+            self._adapter_modules = {}
+        self._adapter_modules[module_key] = adapter_module
+
+    @property
+    def weights(self) -> dict[str, torch.Tensor]:
+        """The state dicts of all registered adapter modules, merged into a single dict.
+
+        Each key is prefixed with the `module_key` its module was registered under.
+        """
+        out: dict[str, torch.Tensor] = {}
+        for module_key, adapter_module in self._adapter_modules.items():
+            for key, tensor in adapter_module.state_dict().items():
+                out[f"{module_key}.{key}"] = tensor
+        return out
+
+    def load_weights(self, weights: dict[str, torch.Tensor]) -> None:
+        """Load weights into the registered adapter modules.
+
+        Args:
+            weights: Tensors keyed by `<module_key>.<state dict key>`, as returned by `weights`.
+                Keys that match no registered module are ignored.
+        """
+        for module_key, adapter_module in self._adapter_modules.items():
+            prefix = f"{module_key}."
+            state_dict: dict[str, torch.Tensor] = {
+                k.removeprefix(prefix): v for k, v in weights.items() if k.startswith(prefix)
+            }
+            adapter_module.load_state_dict(state_dict)
+
     @property
     def mask_decoder_tokens_extender(self) -> MaskDecoderTokensExtender:
         return self._mask_decoder_tokens_extender[0]
diff --git a/tests/foundationals/segment_anything/test_hq_sam.py b/tests/foundationals/segment_anything/test_hq_sam.py
index 8615119..228fd55 100644
--- a/tests/foundationals/segment_anything/test_hq_sam.py
+++ b/tests/foundationals/segment_anything/test_hq_sam.py
@@ -346,3 +346,23 @@ def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) -
     assert mask_prediction.shape == (batch_size, 1, 256, 256)
     assert iou_prediction.shape == (batch_size, 1)
     assert torch.equal(mask_prediction[0], mask_prediction[1])
+
+
+def test_hq_sam_load_save_weights(sam_h: SegmentAnythingH, hq_adapter_weights: Path, test_device: torch.device) -> None:
+    weights = load_from_safetensors(hq_adapter_weights, device=test_device)
+
+    untrained_adapter = HQSAMAdapter(sam_h)
+    out_weights_init = untrained_adapter.weights
+
+    assert set(out_weights_init.keys()) == set(weights.keys())
+
+    hq_sam_adapter = HQSAMAdapter(sam_h, weights=weights)
+    out_weights = hq_sam_adapter.weights
+
+    assert set(out_weights.keys()) == set(weights.keys())
+    for key, tensor in out_weights.items():
+        assert torch.equal(tensor, weights[key])
+
+    # Each adapter must expose its own modules, not those of the most recently created adapter.
+    for key, tensor in untrained_adapter.weights.items():
+        assert tensor.data_ptr() != out_weights[key].data_ptr()