mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 07:08:45 +00:00
SAM init with mask_decoder after #325
This commit is contained in:
parent
5c937b184a
commit
cba83b0558
|
@ -268,8 +268,8 @@ class SegmentAnythingH(SegmentAnything):
|
||||||
|
|
||||||
if mask_decoder:
|
if mask_decoder:
|
||||||
assert (
|
assert (
|
||||||
mask_decoder.multimask_output == multimask_output
|
multimask_output is None or mask_decoder.multimask_output == multimask_output
|
||||||
), f"mask_decoder.multimask_output {mask_decoder.multimask_output} should match multimask_output (${multimask_output})"
|
), f"mask_decoder.multimask_output {mask_decoder.multimask_output} should match multimask_output ({multimask_output})"
|
||||||
else:
|
else:
|
||||||
mask_decoder = MaskDecoder(multimask_output) if multimask_output is not None else MaskDecoder()
|
mask_decoder = MaskDecoder(multimask_output) if multimask_output is not None else MaskDecoder()
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ from refiners.fluxion import manual_seed
|
||||||
from refiners.fluxion.model_converter import ModelConverter
|
from refiners.fluxion.model_converter import ModelConverter
|
||||||
from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
|
from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
|
||||||
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention, RelativePositionAttention
|
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention, RelativePositionAttention
|
||||||
|
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
|
||||||
from refiners.foundationals.segment_anything.model import ImageEmbedding, SegmentAnythingH
|
from refiners.foundationals.segment_anything.model import ImageEmbedding, SegmentAnythingH
|
||||||
from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer
|
from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer
|
||||||
|
|
||||||
|
@ -124,6 +125,19 @@ def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None:
|
||||||
assert torch.equal(input=y_1, other=y_2)
|
assert torch.equal(input=y_1, other=y_2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mask_decoder_arg() -> None:
|
||||||
|
mask_decoder_default = MaskDecoder()
|
||||||
|
sam_h = SegmentAnythingH(mask_decoder=mask_decoder_default)
|
||||||
|
|
||||||
|
assert sam_h.mask_decoder == mask_decoder_default
|
||||||
|
|
||||||
|
|
||||||
|
def test_multimask_output_error() -> None:
|
||||||
|
mask_decoder_multimask_output = MaskDecoder(multimask_output=True)
|
||||||
|
with pytest.raises(AssertionError, match="multimask_output"):
|
||||||
|
SegmentAnythingH(mask_decoder=mask_decoder_multimask_output, multimask_output=False)
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, truck: Image.Image) -> None:
|
def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, truck: Image.Image) -> None:
|
||||||
image_tensor = image_to_tensor(image=truck.resize(size=(1024, 1024)), device=facebook_sam_h.device)
|
image_tensor = image_to_tensor(image=truck.resize(size=(1024, 1024)), device=facebook_sam_h.device)
|
||||||
|
|
Loading…
Reference in a new issue