mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
hq sam: add constructor args to docstring
Additionally, mark `register_adapter_module` for internal use.
This commit is contained in:
parent
d05ebb8dd3
commit
a2ee705783
|
@ -316,6 +316,13 @@ class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]):
|
|||
hq_mask_only: bool = False,
|
||||
weights: dict[str, torch.Tensor] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the adapter.
|
||||
|
||||
Args:
|
||||
target: The SegmentAnything model to adapt.
|
||||
hq_mask_only: Whether to output only the high-quality mask or use it for mask correction (by summing it with the base SAM mask).
|
||||
weights: The weights of the HQSAMAdapter.
|
||||
"""
|
||||
self.vit_embedding_dim = target.image_encoder.embedding_dim
|
||||
self.target_num_mask_tokens = target.mask_decoder.num_multimask_outputs + 2
|
||||
|
||||
|
@ -332,21 +339,21 @@ class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]):
|
|||
mask_prediction, self.vit_embedding_dim, self.target_num_mask_tokens, target.device, target.dtype
|
||||
)
|
||||
]
|
||||
self.register_adapter_module("Chain.HQSAMMaskPrediction", self.mask_prediction_adapter.hq_sam_mask_prediction)
|
||||
self._register_adapter_module("Chain.HQSAMMaskPrediction", self.mask_prediction_adapter.hq_sam_mask_prediction)
|
||||
|
||||
self._image_encoder_adapter = [SAMViTAdapter(target.image_encoder)]
|
||||
self._predictions_post_proc = [PredictionsPostProc(hq_mask_only)]
|
||||
|
||||
mask_decoder_tokens = target.mask_decoder.ensure_find(MaskDecoderTokens)
|
||||
self._mask_decoder_tokens_extender = [MaskDecoderTokensExtender(mask_decoder_tokens)]
|
||||
self.register_adapter_module("MaskDecoderTokensExtender.hq_token", self.mask_decoder_tokens_extender.hq_token)
|
||||
self._register_adapter_module("MaskDecoderTokensExtender.hq_token", self.mask_decoder_tokens_extender.hq_token)
|
||||
|
||||
if weights is not None:
|
||||
self.load_weights(weights)
|
||||
|
||||
self.to(device=target.device, dtype=target.dtype)
|
||||
|
||||
def register_adapter_module(self, module_key: str, adapter_module: fl.Module):
|
||||
def _register_adapter_module(self, module_key: str, adapter_module: fl.Module):
|
||||
"""Register the modules used in weights and load_weights."""
|
||||
self._adapter_modules[module_key] = adapter_module
|
||||
|
||||
|
|
Loading…
Reference in a new issue