segment-anything: fix class name typo

Note: weights are impacted
This commit is contained in:
Cédric Deltheil 2024-01-30 08:44:49 +00:00 committed by Cédric Deltheil
parent 64b52b407f
commit feff4c78ae
4 changed files with 6 additions and 6 deletions

View file

@ -581,7 +581,7 @@ def convert_sam():
"convert_segment_anything.py",
"tests/weights/sam_vit_h_4b8939.pth",
"tests/weights/segment-anything-h.safetensors",
expected_hash="3b73b2fd",
expected_hash="b62ad5ed",
)

View file

@ -5,7 +5,7 @@ import refiners.fluxion.layers as fl
from refiners.fluxion.context import Contexts
from refiners.foundationals.segment_anything.transformer import (
SparseCrossDenseAttention,
TwoWayTranformerLayer,
TwoWayTransformerLayer,
)
@ -210,7 +210,7 @@ class MaskDecoder(fl.Chain):
EmbeddingsAggregator(num_output_mask=num_output_mask),
Transformer(
*(
TwoWayTranformerLayer(
TwoWayTransformerLayer(
embedding_dim=embedding_dim,
num_heads=8,
feed_forward_dim=feed_forward_dim,

View file

@ -116,7 +116,7 @@ class DenseCrossSparseAttention(fl.Chain):
)
class TwoWayTranformerLayer(fl.Chain):
class TwoWayTransformerLayer(fl.Chain):
def __init__(
self,
embedding_dim: int,

View file

@ -21,7 +21,7 @@ from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention
from refiners.foundationals.segment_anything.model import SegmentAnythingH
from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer
from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer
# See predictor_example.ipynb official notebook
PROMPTS: list[SAMPrompt] = [
@ -188,7 +188,7 @@ def test_two_way_transformer(facebook_sam_h: FacebookSAM) -> None:
dense_positional_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device)
sparse_embedding = torch.randn(1, 3, 256, device=facebook_sam_h.device)
refiners_layer = TwoWayTranformerLayer(
refiners_layer = TwoWayTransformerLayer(
embedding_dim=256, feed_forward_dim=2048, num_heads=8, device=facebook_sam_h.device
)
facebook_layer = facebook_sam_h.mask_decoder.transformer.layers[1] # type: ignore