From 94e8b9c23f0b31f0113aa828b8e0d202eb38a7f6 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Wed, 20 Mar 2024 23:43:59 +0000 Subject: [PATCH] SAM MaskDecoder token slicing --- .../foundationals/segment_anything/mask_decoder.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/refiners/foundationals/segment_anything/mask_decoder.py b/src/refiners/foundationals/segment_anything/mask_decoder.py index 25b17aa..74ca86a 100644 --- a/src/refiners/foundationals/segment_anything/mask_decoder.py +++ b/src/refiners/foundationals/segment_anything/mask_decoder.py @@ -56,7 +56,7 @@ class Hypernetworks(fl.Concatenate): super().__init__( *[ fl.Chain( - fl.Slicing(dim=1, start=i + 1, end=i + 2), + fl.Slicing(dim=1, start=i, end=i + 1), fl.MultiLinear( input_dim=embedding_dim, output_dim=embedding_dim // 8, @@ -147,6 +147,8 @@ class MaskPrediction(fl.Chain): start_mask, num_masks = (1, num_mask_tokens - 1) if multimask_output else (0, 1) super().__init__( + # rm unused tokens : 1st token (iou token) + last tokens (prompt tokens) + fl.Slicing(dim=1, start=1, end=num_mask_tokens + 1), fl.Matmul( input=Hypernetworks( embedding_dim=embedding_dim, @@ -177,7 +179,8 @@ class IOUPrediction(fl.Chain): self.multimask_output = multimask_output super().__init__( - fl.Slicing(dim=1, start=0, end=1), + fl.Slicing(dim=1, start=0, end=1), # iou_token + fl.Squeeze(dim=1), fl.MultiLinear( input_dim=embedding_dim, output_dim=num_mask_tokens, @@ -187,7 +190,6 @@ class IOUPrediction(fl.Chain): dtype=dtype, ), fl.Slicing(dim=-1, start=1) if multimask_output else fl.Slicing(dim=-1, start=0, end=1), - fl.Squeeze(dim=1), )