mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
SAM MaskDecoder token slicing
This commit is contained in:
parent
a93ceff752
commit
94e8b9c23f
|
@ -56,7 +56,7 @@ class Hypernetworks(fl.Concatenate):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
*[
|
*[
|
||||||
fl.Chain(
|
fl.Chain(
|
||||||
fl.Slicing(dim=1, start=i + 1, end=i + 2),
|
fl.Slicing(dim=1, start=i, end=i + 1),
|
||||||
fl.MultiLinear(
|
fl.MultiLinear(
|
||||||
input_dim=embedding_dim,
|
input_dim=embedding_dim,
|
||||||
output_dim=embedding_dim // 8,
|
output_dim=embedding_dim // 8,
|
||||||
|
@ -147,6 +147,8 @@ class MaskPrediction(fl.Chain):
|
||||||
start_mask, num_masks = (1, num_mask_tokens - 1) if multimask_output else (0, 1)
|
start_mask, num_masks = (1, num_mask_tokens - 1) if multimask_output else (0, 1)
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
# rm unused tokens : 1st token (iou token) + last tokens (prompt tokens)
|
||||||
|
fl.Slicing(dim=1, start=1, end=num_mask_tokens + 1),
|
||||||
fl.Matmul(
|
fl.Matmul(
|
||||||
input=Hypernetworks(
|
input=Hypernetworks(
|
||||||
embedding_dim=embedding_dim,
|
embedding_dim=embedding_dim,
|
||||||
|
@ -177,7 +179,8 @@ class IOUPrediction(fl.Chain):
|
||||||
self.multimask_output = multimask_output
|
self.multimask_output = multimask_output
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Slicing(dim=1, start=0, end=1),
|
fl.Slicing(dim=1, start=0, end=1), # iou_token
|
||||||
|
fl.Squeeze(dim=1),
|
||||||
fl.MultiLinear(
|
fl.MultiLinear(
|
||||||
input_dim=embedding_dim,
|
input_dim=embedding_dim,
|
||||||
output_dim=num_mask_tokens,
|
output_dim=num_mask_tokens,
|
||||||
|
@ -187,7 +190,6 @@ class IOUPrediction(fl.Chain):
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
),
|
),
|
||||||
fl.Slicing(dim=-1, start=1) if multimask_output else fl.Slicing(dim=-1, start=0, end=1),
|
fl.Slicing(dim=-1, start=1) if multimask_output else fl.Slicing(dim=-1, start=0, end=1),
|
||||||
fl.Squeeze(dim=1),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue