hq sam: add constructor args to docstring

Additionally, mark `register_adapter_module` for internal use.
This commit is contained in:
Cédric Deltheil 2024-04-08 08:48:32 +00:00 committed by Cédric Deltheil
parent d05ebb8dd3
commit a2ee705783

View file

@ -316,6 +316,13 @@ class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]):
hq_mask_only: bool = False,
weights: dict[str, torch.Tensor] | None = None,
) -> None:
"""Initialize the adapter.
Args:
target: The SegmentAnything model to adapt.
hq_mask_only: Whether to output only the high-quality mask or use it for mask correction (by summing it with the base SAM mask).
weights: The weights of the HQSAMAdapter.
"""
self.vit_embedding_dim = target.image_encoder.embedding_dim
self.target_num_mask_tokens = target.mask_decoder.num_multimask_outputs + 2
@ -332,21 +339,21 @@ class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]):
mask_prediction, self.vit_embedding_dim, self.target_num_mask_tokens, target.device, target.dtype
)
]
self.register_adapter_module("Chain.HQSAMMaskPrediction", self.mask_prediction_adapter.hq_sam_mask_prediction)
self._register_adapter_module("Chain.HQSAMMaskPrediction", self.mask_prediction_adapter.hq_sam_mask_prediction)
self._image_encoder_adapter = [SAMViTAdapter(target.image_encoder)]
self._predictions_post_proc = [PredictionsPostProc(hq_mask_only)]
mask_decoder_tokens = target.mask_decoder.ensure_find(MaskDecoderTokens)
self._mask_decoder_tokens_extender = [MaskDecoderTokensExtender(mask_decoder_tokens)]
self.register_adapter_module("MaskDecoderTokensExtender.hq_token", self.mask_decoder_tokens_extender.hq_token)
self._register_adapter_module("MaskDecoderTokensExtender.hq_token", self.mask_decoder_tokens_extender.hq_token)
if weights is not None:
self.load_weights(weights)
self.to(device=target.device, dtype=target.dtype)
def register_adapter_module(self, module_key: str, adapter_module: fl.Module):
def _register_adapter_module(self, module_key: str, adapter_module: fl.Module):
"""Register the modules used in weights and load_weights."""
self._adapter_modules[module_key] = adapter_module