mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-14 09:08:14 +00:00
hq sam: add constructor args to docstring
Additionally, mark `register_adapter_module` for internal use.
This commit is contained in:
parent
d05ebb8dd3
commit
a2ee705783
|
@ -316,6 +316,13 @@ class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]):
|
||||||
hq_mask_only: bool = False,
|
hq_mask_only: bool = False,
|
||||||
weights: dict[str, torch.Tensor] | None = None,
|
weights: dict[str, torch.Tensor] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Initialize the adapter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target: The SegmentAnything model to adapt.
|
||||||
|
hq_mask_only: Whether to output only the high-quality mask or use it for mask correction (by summing it with the base SAM mask).
|
||||||
|
weights: The weights of the HQSAMAdapter.
|
||||||
|
"""
|
||||||
self.vit_embedding_dim = target.image_encoder.embedding_dim
|
self.vit_embedding_dim = target.image_encoder.embedding_dim
|
||||||
self.target_num_mask_tokens = target.mask_decoder.num_multimask_outputs + 2
|
self.target_num_mask_tokens = target.mask_decoder.num_multimask_outputs + 2
|
||||||
|
|
||||||
|
@ -332,21 +339,21 @@ class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]):
|
||||||
mask_prediction, self.vit_embedding_dim, self.target_num_mask_tokens, target.device, target.dtype
|
mask_prediction, self.vit_embedding_dim, self.target_num_mask_tokens, target.device, target.dtype
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
self.register_adapter_module("Chain.HQSAMMaskPrediction", self.mask_prediction_adapter.hq_sam_mask_prediction)
|
self._register_adapter_module("Chain.HQSAMMaskPrediction", self.mask_prediction_adapter.hq_sam_mask_prediction)
|
||||||
|
|
||||||
self._image_encoder_adapter = [SAMViTAdapter(target.image_encoder)]
|
self._image_encoder_adapter = [SAMViTAdapter(target.image_encoder)]
|
||||||
self._predictions_post_proc = [PredictionsPostProc(hq_mask_only)]
|
self._predictions_post_proc = [PredictionsPostProc(hq_mask_only)]
|
||||||
|
|
||||||
mask_decoder_tokens = target.mask_decoder.ensure_find(MaskDecoderTokens)
|
mask_decoder_tokens = target.mask_decoder.ensure_find(MaskDecoderTokens)
|
||||||
self._mask_decoder_tokens_extender = [MaskDecoderTokensExtender(mask_decoder_tokens)]
|
self._mask_decoder_tokens_extender = [MaskDecoderTokensExtender(mask_decoder_tokens)]
|
||||||
self.register_adapter_module("MaskDecoderTokensExtender.hq_token", self.mask_decoder_tokens_extender.hq_token)
|
self._register_adapter_module("MaskDecoderTokensExtender.hq_token", self.mask_decoder_tokens_extender.hq_token)
|
||||||
|
|
||||||
if weights is not None:
|
if weights is not None:
|
||||||
self.load_weights(weights)
|
self.load_weights(weights)
|
||||||
|
|
||||||
self.to(device=target.device, dtype=target.dtype)
|
self.to(device=target.device, dtype=target.dtype)
|
||||||
|
|
||||||
def register_adapter_module(self, module_key: str, adapter_module: fl.Module):
|
def _register_adapter_module(self, module_key: str, adapter_module: fl.Module):
|
||||||
"""Register the modules used in weights and load_weights."""
|
"""Register the modules used in weights and load_weights."""
|
||||||
self._adapter_modules[module_key] = adapter_module
|
self._adapter_modules[module_key] = adapter_module
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue