mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
segment-anything: fix class name typo
Note: weights are impacted
This commit is contained in:
parent
64b52b407f
commit
feff4c78ae
|
@ -581,7 +581,7 @@ def convert_sam():
|
|||
"convert_segment_anything.py",
|
||||
"tests/weights/sam_vit_h_4b8939.pth",
|
||||
"tests/weights/segment-anything-h.safetensors",
|
||||
expected_hash="3b73b2fd",
|
||||
expected_hash="b62ad5ed",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import refiners.fluxion.layers as fl
|
|||
from refiners.fluxion.context import Contexts
|
||||
from refiners.foundationals.segment_anything.transformer import (
|
||||
SparseCrossDenseAttention,
|
||||
TwoWayTranformerLayer,
|
||||
TwoWayTransformerLayer,
|
||||
)
|
||||
|
||||
|
||||
|
@ -210,7 +210,7 @@ class MaskDecoder(fl.Chain):
|
|||
EmbeddingsAggregator(num_output_mask=num_output_mask),
|
||||
Transformer(
|
||||
*(
|
||||
TwoWayTranformerLayer(
|
||||
TwoWayTransformerLayer(
|
||||
embedding_dim=embedding_dim,
|
||||
num_heads=8,
|
||||
feed_forward_dim=feed_forward_dim,
|
||||
|
|
|
@ -116,7 +116,7 @@ class DenseCrossSparseAttention(fl.Chain):
|
|||
)
|
||||
|
||||
|
||||
class TwoWayTranformerLayer(fl.Chain):
|
||||
class TwoWayTransformerLayer(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
|
|
|
@ -21,7 +21,7 @@ from refiners.fluxion.model_converter import ModelConverter
|
|||
from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
|
||||
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention
|
||||
from refiners.foundationals.segment_anything.model import SegmentAnythingH
|
||||
from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer
|
||||
from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer
|
||||
|
||||
# See predictor_example.ipynb official notebook
|
||||
PROMPTS: list[SAMPrompt] = [
|
||||
|
@ -188,7 +188,7 @@ def test_two_way_transformer(facebook_sam_h: FacebookSAM) -> None:
|
|||
dense_positional_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device)
|
||||
sparse_embedding = torch.randn(1, 3, 256, device=facebook_sam_h.device)
|
||||
|
||||
refiners_layer = TwoWayTranformerLayer(
|
||||
refiners_layer = TwoWayTransformerLayer(
|
||||
embedding_dim=256, feed_forward_dim=2048, num_heads=8, device=facebook_sam_h.device
|
||||
)
|
||||
facebook_layer = facebook_sam_h.mask_decoder.transformer.layers[1] # type: ignore
|
||||
|
|
Loading…
Reference in a new issue