SAM init with mask_decoder after #325

This commit is contained in:
Pierre Colle 2024-03-22 21:50:27 +00:00
parent 5c937b184a
commit c6c2105e11
2 changed files with 16 additions and 2 deletions

View file

@ -268,8 +268,8 @@ class SegmentAnythingH(SegmentAnything):
if mask_decoder:
assert (
mask_decoder.multimask_output == multimask_output
), f"mask_decoder.multimask_output {mask_decoder.multimask_output} should match multimask_output (${multimask_output})"
multimask_output is None or mask_decoder.multimask_output == multimask_output
), f"mask_decoder.multimask_output {mask_decoder.multimask_output} should match multimask_output ({multimask_output})"
else:
mask_decoder = MaskDecoder(multimask_output) if multimask_output is not None else MaskDecoder()

View file

@ -21,6 +21,7 @@ from refiners.fluxion import manual_seed
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention, RelativePositionAttention
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
from refiners.foundationals.segment_anything.model import ImageEmbedding, SegmentAnythingH
from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer
@ -124,6 +125,19 @@ def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None:
assert torch.equal(input=y_1, other=y_2)
def test_mask_decoder_arg() -> None:
mask_decoder_default = MaskDecoder()
sam_h = SegmentAnythingH(mask_decoder=mask_decoder_default)
assert sam_h.mask_decoder == mask_decoder_default
def test_multimask_output_error() -> None:
mask_decoder_multimask_output = MaskDecoder(multimask_output=True)
with pytest.raises(AssertionError, match="multimask_output"):
SegmentAnythingH(mask_decoder=mask_decoder_multimask_output, multimask_output=False)
@no_grad()
def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, truck: Image.Image) -> None:
image_tensor = image_to_tensor(image=truck.resize(size=(1024, 1024)), device=facebook_sam_h.device)