From a2ee70578361e4d84a65a8708564480a9b0ec67e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?C=C3=A9dric=20Deltheil?=
Date: Mon, 8 Apr 2024 08:48:32 +0000
Subject: [PATCH] hq sam: add constructor args to docstring

Additionally, mark `register_adapter_module` for internal use.
---
 .../foundationals/segment_anything/hq_sam.py  | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/src/refiners/foundationals/segment_anything/hq_sam.py b/src/refiners/foundationals/segment_anything/hq_sam.py
index 87129f5..4ffbbbb 100644
--- a/src/refiners/foundationals/segment_anything/hq_sam.py
+++ b/src/refiners/foundationals/segment_anything/hq_sam.py
@@ -316,6 +316,13 @@ class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]):
         hq_mask_only: bool = False,
         weights: dict[str, torch.Tensor] | None = None,
     ) -> None:
+        """Initialize the adapter.
+
+        Args:
+            target: The SegmentAnything model to adapt.
+            hq_mask_only: Whether to output only the high-quality mask or use it for mask correction (by summing it with the base SAM mask).
+            weights: The weights of the HQSAMAdapter.
+        """
         self.vit_embedding_dim = target.image_encoder.embedding_dim
         self.target_num_mask_tokens = target.mask_decoder.num_multimask_outputs + 2
 
@@ -332,21 +339,21 @@ class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]):
                 mask_prediction, self.vit_embedding_dim, self.target_num_mask_tokens, target.device, target.dtype
             )
         ]
-        self.register_adapter_module("Chain.HQSAMMaskPrediction", self.mask_prediction_adapter.hq_sam_mask_prediction)
+        self._register_adapter_module("Chain.HQSAMMaskPrediction", self.mask_prediction_adapter.hq_sam_mask_prediction)
 
         self._image_encoder_adapter = [SAMViTAdapter(target.image_encoder)]
         self._predictions_post_proc = [PredictionsPostProc(hq_mask_only)]
 
         mask_decoder_tokens = target.mask_decoder.ensure_find(MaskDecoderTokens)
         self._mask_decoder_tokens_extender = [MaskDecoderTokensExtender(mask_decoder_tokens)]
-        self.register_adapter_module("MaskDecoderTokensExtender.hq_token", self.mask_decoder_tokens_extender.hq_token)
+        self._register_adapter_module("MaskDecoderTokensExtender.hq_token", self.mask_decoder_tokens_extender.hq_token)
 
         if weights is not None:
             self.load_weights(weights)
 
         self.to(device=target.device, dtype=target.dtype)
 
-    def register_adapter_module(self, module_key: str, adapter_module: fl.Module):
+    def _register_adapter_module(self, module_key: str, adapter_module: fl.Module):
         """Register the modules used in weights and load_weights."""
         self._adapter_modules[module_key] = adapter_module