hq sam: add constructor args to docstring

Additionally, mark `register_adapter_module` for internal use.
This commit is contained in:
Cédric Deltheil 2024-04-08 08:48:32 +00:00 committed by Cédric Deltheil
parent d05ebb8dd3
commit a2ee705783

View file

@ -316,6 +316,13 @@ class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]):
hq_mask_only: bool = False, hq_mask_only: bool = False,
weights: dict[str, torch.Tensor] | None = None, weights: dict[str, torch.Tensor] | None = None,
) -> None: ) -> None:
"""Initialize the adapter.
Args:
target: The SegmentAnything model to adapt.
hq_mask_only: Whether to output only the high-quality mask or use it for mask correction (by summing it with the base SAM mask).
weights: The weights of the HQSAMAdapter.
"""
self.vit_embedding_dim = target.image_encoder.embedding_dim self.vit_embedding_dim = target.image_encoder.embedding_dim
self.target_num_mask_tokens = target.mask_decoder.num_multimask_outputs + 2 self.target_num_mask_tokens = target.mask_decoder.num_multimask_outputs + 2
@ -332,21 +339,21 @@ class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]):
mask_prediction, self.vit_embedding_dim, self.target_num_mask_tokens, target.device, target.dtype mask_prediction, self.vit_embedding_dim, self.target_num_mask_tokens, target.device, target.dtype
) )
] ]
self.register_adapter_module("Chain.HQSAMMaskPrediction", self.mask_prediction_adapter.hq_sam_mask_prediction) self._register_adapter_module("Chain.HQSAMMaskPrediction", self.mask_prediction_adapter.hq_sam_mask_prediction)
self._image_encoder_adapter = [SAMViTAdapter(target.image_encoder)] self._image_encoder_adapter = [SAMViTAdapter(target.image_encoder)]
self._predictions_post_proc = [PredictionsPostProc(hq_mask_only)] self._predictions_post_proc = [PredictionsPostProc(hq_mask_only)]
mask_decoder_tokens = target.mask_decoder.ensure_find(MaskDecoderTokens) mask_decoder_tokens = target.mask_decoder.ensure_find(MaskDecoderTokens)
self._mask_decoder_tokens_extender = [MaskDecoderTokensExtender(mask_decoder_tokens)] self._mask_decoder_tokens_extender = [MaskDecoderTokensExtender(mask_decoder_tokens)]
self.register_adapter_module("MaskDecoderTokensExtender.hq_token", self.mask_decoder_tokens_extender.hq_token) self._register_adapter_module("MaskDecoderTokensExtender.hq_token", self.mask_decoder_tokens_extender.hq_token)
if weights is not None: if weights is not None:
self.load_weights(weights) self.load_weights(weights)
self.to(device=target.device, dtype=target.dtype) self.to(device=target.device, dtype=target.dtype)
def register_adapter_module(self, module_key: str, adapter_module: fl.Module): def _register_adapter_module(self, module_key: str, adapter_module: fl.Module):
"""Register the modules used in weights and load_weights.""" """Register the modules used in weights and load_weights."""
self._adapter_modules[module_key] = adapter_module self._adapter_modules[module_key] = adapter_module