segment-anything: fix class name typo

Note: weights are impacted
Cédric Deltheil 2024-01-30 08:44:49 +00:00 committed by Cédric Deltheil
parent 64b52b407f
commit feff4c78ae
4 changed files with 6 additions and 6 deletions


@@ -581,7 +581,7 @@ def convert_sam():
         "convert_segment_anything.py",
         "tests/weights/sam_vit_h_4b8939.pth",
         "tests/weights/segment-anything-h.safetensors",
-        expected_hash="3b73b2fd",
+        expected_hash="b62ad5ed",
     )
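The new expected_hash reflects that the converted file itself changes: the layer keys in the converted state dict are derived from the chain's class names, so renaming the class renames the keys (hence the note above that weights are impacted). Below is a minimal, hypothetical sketch for recomputing a truncated content hash after regenerating the weights; the helper name and the assumption that expected_hash is the first 8 hex characters of a SHA-256 over the file bytes are mine, not necessarily the project's actual scheme.

    # Hypothetical sketch: recompute a short content hash for the regenerated
    # file, assuming expected_hash is a truncated SHA-256 of the file bytes.
    import hashlib
    from pathlib import Path

    def short_hash(path: str, length: int = 8) -> str:
        digest = hashlib.sha256(Path(path).read_bytes()).hexdigest()
        return digest[:length]

    print(short_hash("tests/weights/segment-anything-h.safetensors"))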


@@ -5,7 +5,7 @@ import refiners.fluxion.layers as fl
 from refiners.fluxion.context import Contexts
 from refiners.foundationals.segment_anything.transformer import (
     SparseCrossDenseAttention,
-    TwoWayTranformerLayer,
+    TwoWayTransformerLayer,
 )
@@ -210,7 +210,7 @@ class MaskDecoder(fl.Chain):
             EmbeddingsAggregator(num_output_mask=num_output_mask),
             Transformer(
                 *(
-                    TwoWayTranformerLayer(
+                    TwoWayTransformerLayer(
                         embedding_dim=embedding_dim,
                         num_heads=8,
                         feed_forward_dim=feed_forward_dim,


@@ -116,7 +116,7 @@ class DenseCrossSparseAttention(fl.Chain):
     )
 
 
-class TwoWayTranformerLayer(fl.Chain):
+class TwoWayTransformerLayer(fl.Chain):
     def __init__(
         self,
         embedding_dim: int,
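For reference, the renamed class can be imported and constructed on its own; the sketch below is not part of the commit and simply reuses the SAM-H hyperparameters exercised by test_two_way_transformer further down.

    # Sketch: standalone construction of the renamed layer with the SAM-H
    # hyperparameters used in the test below (device/dtype left at defaults).
    from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer

    layer = TwoWayTransformerLayer(
        embedding_dim=256,
        num_heads=8,
        feed_forward_dim=2048,
    )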


@@ -21,7 +21,7 @@ from refiners.fluxion.model_converter import ModelConverter
 from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
 from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention
 from refiners.foundationals.segment_anything.model import SegmentAnythingH
-from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer
+from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer
 
 # See predictor_example.ipynb official notebook
 PROMPTS: list[SAMPrompt] = [
@@ -188,7 +188,7 @@ def test_two_way_transformer(facebook_sam_h: FacebookSAM) -> None:
     dense_positional_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device)
     sparse_embedding = torch.randn(1, 3, 256, device=facebook_sam_h.device)
-    refiners_layer = TwoWayTranformerLayer(
+    refiners_layer = TwoWayTransformerLayer(
         embedding_dim=256, feed_forward_dim=2048, num_heads=8, device=facebook_sam_h.device
     )
     facebook_layer = facebook_sam_h.mask_decoder.transformer.layers[1]  # type: ignore
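Since the converted weight keys change with the rename, a segment-anything-h.safetensors generated before this commit must be re-converted before running these tests. A minimal sketch of reloading the re-converted weights, assuming the load_from_safetensors utility and a plain load_state_dict round-trip:

    # Sketch (assumed loading path): the re-converted weights should load
    # cleanly into the refiners model, while a file converted before the
    # rename would fail with mismatched keys.
    from refiners.fluxion.utils import load_from_safetensors
    from refiners.foundationals.segment_anything.model import SegmentAnythingH

    sam = SegmentAnythingH()
    sam.load_state_dict(load_from_safetensors("tests/weights/segment-anything-h.safetensors"))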