mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
SAM init with mask_decoder after #325
This commit is contained in:
parent
5c937b184a
commit
cba83b0558
|
@ -268,8 +268,8 @@ class SegmentAnythingH(SegmentAnything):
|
|||
|
||||
if mask_decoder:
|
||||
assert (
|
||||
mask_decoder.multimask_output == multimask_output
|
||||
), f"mask_decoder.multimask_output {mask_decoder.multimask_output} should match multimask_output (${multimask_output})"
|
||||
multimask_output is None or mask_decoder.multimask_output == multimask_output
|
||||
), f"mask_decoder.multimask_output {mask_decoder.multimask_output} should match multimask_output ({multimask_output})"
|
||||
else:
|
||||
mask_decoder = MaskDecoder(multimask_output) if multimask_output is not None else MaskDecoder()
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ from refiners.fluxion import manual_seed
|
|||
from refiners.fluxion.model_converter import ModelConverter
|
||||
from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
|
||||
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention, RelativePositionAttention
|
||||
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
|
||||
from refiners.foundationals.segment_anything.model import ImageEmbedding, SegmentAnythingH
|
||||
from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer
|
||||
|
||||
|
@ -124,6 +125,19 @@ def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None:
|
|||
assert torch.equal(input=y_1, other=y_2)
|
||||
|
||||
|
||||
def test_mask_decoder_arg() -> None:
|
||||
mask_decoder_default = MaskDecoder()
|
||||
sam_h = SegmentAnythingH(mask_decoder=mask_decoder_default)
|
||||
|
||||
assert sam_h.mask_decoder == mask_decoder_default
|
||||
|
||||
|
||||
def test_multimask_output_error() -> None:
|
||||
mask_decoder_multimask_output = MaskDecoder(multimask_output=True)
|
||||
with pytest.raises(AssertionError, match="multimask_output"):
|
||||
SegmentAnythingH(mask_decoder=mask_decoder_multimask_output, multimask_output=False)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, truck: Image.Image) -> None:
|
||||
image_tensor = image_to_tensor(image=truck.resize(size=(1024, 1024)), device=facebook_sam_h.device)
|
||||
|
|
Loading…
Reference in a new issue