Mirror of https://github.com/finegrain-ai/refiners.git

commit feff4c78ae (parent 64b52b407f)

segment-anything: fix class name typo

Note: weights are impacted.
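The weights are impacted presumably because the converted state-dict keys follow the Fluxion chain layout, which names layers after their classes, so correcting the class name also changes the keys, and with them the contents and hash of tests/weights/segment-anything-h.safetensors. That reading is an inference from the hash change in the first hunk below; the commit itself only notes the impact. A hypothetical migration sketch for an already-converted file follows; the key layout is an assumption, and re-running the conversion script is the straightforward path.

    # Hypothetical sketch only: rename keys in an already-converted safetensors
    # file so they match the corrected class name. Assumes the state-dict keys
    # embed the class name "TwoWayTranformerLayer", which this commit does not
    # confirm; re-running convert_segment_anything.py is the supported route.
    from safetensors.torch import load_file, save_file

    OLD_NAME = "TwoWayTranformerLayer"   # misspelled class name
    NEW_NAME = "TwoWayTransformerLayer"  # corrected class name

    tensors = load_file("tests/weights/segment-anything-h.safetensors")
    renamed = {key.replace(OLD_NAME, NEW_NAME): value for key, value in tensors.items()}
    save_file(renamed, "tests/weights/segment-anything-h.safetensors")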
@@ -581,7 +581,7 @@ def convert_sam():
         "convert_segment_anything.py",
         "tests/weights/sam_vit_h_4b8939.pth",
         "tests/weights/segment-anything-h.safetensors",
-        expected_hash="3b73b2fd",
+        expected_hash="b62ad5ed",
     )

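Since the regenerated safetensors file has different contents, the recorded expected_hash changes from 3b73b2fd to b62ad5ed. A minimal sketch of recomputing such a short hash, assuming it is a truncated SHA-256 of the file (the repository's own hashing helper may compute it differently):

    # Recompute a short content hash for the regenerated weights file.
    # Assumption: expected_hash is the first 8 hex characters of the file's
    # SHA-256 digest; the actual helper used by the repository may differ.
    import hashlib
    from pathlib import Path

    def short_hash(path: str, length: int = 8) -> str:
        digest = hashlib.sha256(Path(path).read_bytes()).hexdigest()
        return digest[:length]

    print(short_hash("tests/weights/segment-anything-h.safetensors"))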
@@ -5,7 +5,7 @@ import refiners.fluxion.layers as fl
 from refiners.fluxion.context import Contexts
 from refiners.foundationals.segment_anything.transformer import (
     SparseCrossDenseAttention,
-    TwoWayTranformerLayer,
+    TwoWayTransformerLayer,
 )

@@ -210,7 +210,7 @@ class MaskDecoder(fl.Chain):
     EmbeddingsAggregator(num_output_mask=num_output_mask),
     Transformer(
         *(
-            TwoWayTranformerLayer(
+            TwoWayTransformerLayer(
                 embedding_dim=embedding_dim,
                 num_heads=8,
                 feed_forward_dim=feed_forward_dim,
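For context, this hunk sits inside the MaskDecoder constructor, where the Transformer chain is filled with repeated two-way layers. A rough sketch of that construction pattern, using fl.Chain as a stand-in for the decoder's Transformer chain; the generator expression and the depth value are assumptions for illustration:

    # Sketch of the pattern visible in this hunk: a chain filled with repeated
    # TwoWayTransformerLayer instances. The depth value and the use of fl.Chain
    # in place of the decoder's own Transformer class are assumptions.
    import refiners.fluxion.layers as fl
    from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer

    depth, embedding_dim, feed_forward_dim = 2, 256, 2048  # assumed values
    transformer = fl.Chain(
        *(
            TwoWayTransformerLayer(
                embedding_dim=embedding_dim,
                num_heads=8,
                feed_forward_dim=feed_forward_dim,
            )
            for _ in range(depth)
        )
    )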
@@ -116,7 +116,7 @@ class DenseCrossSparseAttention(fl.Chain):
         )


-class TwoWayTranformerLayer(fl.Chain):
+class TwoWayTransformerLayer(fl.Chain):
     def __init__(
         self,
         embedding_dim: int,
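After the definition is corrected, the layer is exposed under the fixed spelling only. A small sanity check, assuming the keyword arguments used in the test below are sufficient and that device and dtype can be left at their defaults:

    # Minimal sanity check: the corrected name imports and builds an fl.Chain.
    # Assumes embedding_dim / num_heads / feed_forward_dim are sufficient
    # constructor arguments (device and dtype left at their defaults).
    import refiners.fluxion.layers as fl
    from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer

    layer = TwoWayTransformerLayer(embedding_dim=256, num_heads=8, feed_forward_dim=2048)
    assert isinstance(layer, fl.Chain)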
@@ -21,7 +21,7 @@ from refiners.fluxion.model_converter import ModelConverter
 from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
 from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention
 from refiners.foundationals.segment_anything.model import SegmentAnythingH
-from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer
+from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer

 # See predictor_example.ipynb official notebook
 PROMPTS: list[SAMPrompt] = [
@@ -188,7 +188,7 @@ def test_two_way_transformer(facebook_sam_h: FacebookSAM) -> None:
     dense_positional_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device)
     sparse_embedding = torch.randn(1, 3, 256, device=facebook_sam_h.device)

-    refiners_layer = TwoWayTranformerLayer(
+    refiners_layer = TwoWayTransformerLayer(
         embedding_dim=256, feed_forward_dim=2048, num_heads=8, device=facebook_sam_h.device
     )
     facebook_layer = facebook_sam_h.mask_decoder.transformer.layers[1]  # type: ignore
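When upgrading past this commit, code that still imports the misspelled name will fail. A hypothetical check (not part of the commit's test suite) that the public module now exposes only the corrected symbol:

    # Hypothetical check: only the corrected class name should be importable.
    import importlib

    transformer_module = importlib.import_module("refiners.foundationals.segment_anything.transformer")
    assert hasattr(transformer_module, "TwoWayTransformerLayer")
    assert not hasattr(transformer_module, "TwoWayTranformerLayer")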