SAM MaskDecoder token slicing

This commit is contained in:
Pierre Colle 2024-03-20 23:43:59 +00:00 committed by Colle
parent a93ceff752
commit 94e8b9c23f

View file

@ -56,7 +56,7 @@ class Hypernetworks(fl.Concatenate):
super().__init__( super().__init__(
*[ *[
fl.Chain( fl.Chain(
fl.Slicing(dim=1, start=i + 1, end=i + 2), fl.Slicing(dim=1, start=i, end=i + 1),
fl.MultiLinear( fl.MultiLinear(
input_dim=embedding_dim, input_dim=embedding_dim,
output_dim=embedding_dim // 8, output_dim=embedding_dim // 8,
@ -147,6 +147,8 @@ class MaskPrediction(fl.Chain):
start_mask, num_masks = (1, num_mask_tokens - 1) if multimask_output else (0, 1) start_mask, num_masks = (1, num_mask_tokens - 1) if multimask_output else (0, 1)
super().__init__( super().__init__(
# remove unused tokens: the first token (IoU token) and the trailing tokens (prompt tokens)
fl.Slicing(dim=1, start=1, end=num_mask_tokens + 1),
fl.Matmul( fl.Matmul(
input=Hypernetworks( input=Hypernetworks(
embedding_dim=embedding_dim, embedding_dim=embedding_dim,
@ -177,7 +179,8 @@ class IOUPrediction(fl.Chain):
self.multimask_output = multimask_output self.multimask_output = multimask_output
super().__init__( super().__init__(
fl.Slicing(dim=1, start=0, end=1), fl.Slicing(dim=1, start=0, end=1), # iou_token
fl.Squeeze(dim=1),
fl.MultiLinear( fl.MultiLinear(
input_dim=embedding_dim, input_dim=embedding_dim,
output_dim=num_mask_tokens, output_dim=num_mask_tokens,
@ -187,7 +190,6 @@ class IOUPrediction(fl.Chain):
dtype=dtype, dtype=dtype,
), ),
fl.Slicing(dim=-1, start=1) if multimask_output else fl.Slicing(dim=-1, start=0, end=1), fl.Slicing(dim=-1, start=1) if multimask_output else fl.Slicing(dim=-1, start=0, end=1),
fl.Squeeze(dim=1),
) )