From feff4c78aebc216562f136d15c62dbf476031329 Mon Sep 17 00:00:00 2001
From: Cédric Deltheil
Date: Tue, 30 Jan 2024 08:44:49 +0000
Subject: [PATCH] segment-anything: fix class name typo

Note: weights are impacted
---
 scripts/prepare_test_weights.py                             | 2 +-
 src/refiners/foundationals/segment_anything/mask_decoder.py | 4 ++--
 src/refiners/foundationals/segment_anything/transformer.py  | 2 +-
 tests/foundationals/segment_anything/test_sam.py            | 4 ++--
 4 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/scripts/prepare_test_weights.py b/scripts/prepare_test_weights.py
index 19f9423..f611e1e 100644
--- a/scripts/prepare_test_weights.py
+++ b/scripts/prepare_test_weights.py
@@ -581,7 +581,7 @@ def convert_sam():
         "convert_segment_anything.py",
         "tests/weights/sam_vit_h_4b8939.pth",
         "tests/weights/segment-anything-h.safetensors",
-        expected_hash="3b73b2fd",
+        expected_hash="b62ad5ed",
     )


diff --git a/src/refiners/foundationals/segment_anything/mask_decoder.py b/src/refiners/foundationals/segment_anything/mask_decoder.py
index cd19668..0f9c477 100644
--- a/src/refiners/foundationals/segment_anything/mask_decoder.py
+++ b/src/refiners/foundationals/segment_anything/mask_decoder.py
@@ -5,7 +5,7 @@ import refiners.fluxion.layers as fl
 from refiners.fluxion.context import Contexts
 from refiners.foundationals.segment_anything.transformer import (
     SparseCrossDenseAttention,
-    TwoWayTranformerLayer,
+    TwoWayTransformerLayer,
 )


@@ -210,7 +210,7 @@ class MaskDecoder(fl.Chain):
             EmbeddingsAggregator(num_output_mask=num_output_mask),
             Transformer(
                 *(
-                    TwoWayTranformerLayer(
+                    TwoWayTransformerLayer(
                         embedding_dim=embedding_dim,
                         num_heads=8,
                         feed_forward_dim=feed_forward_dim,
diff --git a/src/refiners/foundationals/segment_anything/transformer.py b/src/refiners/foundationals/segment_anything/transformer.py
index 5fb13d7..7abe2f5 100644
--- a/src/refiners/foundationals/segment_anything/transformer.py
+++ b/src/refiners/foundationals/segment_anything/transformer.py
@@ -116,7 +116,7 @@ class DenseCrossSparseAttention(fl.Chain):
     )


-class TwoWayTranformerLayer(fl.Chain):
+class TwoWayTransformerLayer(fl.Chain):
     def __init__(
         self,
         embedding_dim: int,
diff --git a/tests/foundationals/segment_anything/test_sam.py b/tests/foundationals/segment_anything/test_sam.py
index 8b21c15..1a62217 100644
--- a/tests/foundationals/segment_anything/test_sam.py
+++ b/tests/foundationals/segment_anything/test_sam.py
@@ -21,7 +21,7 @@ from refiners.fluxion.model_converter import ModelConverter
 from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
 from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention
 from refiners.foundationals.segment_anything.model import SegmentAnythingH
-from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer
+from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer

 # See predictor_example.ipynb official notebook
 PROMPTS: list[SAMPrompt] = [
@@ -188,7 +188,7 @@ def test_two_way_transformer(facebook_sam_h: FacebookSAM) -> None:
     dense_positional_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device)
     sparse_embedding = torch.randn(1, 3, 256, device=facebook_sam_h.device)

-    refiners_layer = TwoWayTranformerLayer(
+    refiners_layer = TwoWayTransformerLayer(
         embedding_dim=256, feed_forward_dim=2048, num_heads=8, device=facebook_sam_h.device
     )
     facebook_layer = facebook_sam_h.mask_decoder.transformer.layers[1]  # type: ignore
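
Side note on why the converted weights are impacted (a minimal sketch, not part of the patch): the expected_hash bump in prepare_test_weights.py suggests that refiners' fl.Chain derives submodule names, and hence state-dict keys, from the class names of its children, so renaming the layer class renames every matching key in the converted segment-anything-h.safetensors. The snippet below is a hypothetical check based on that assumption; it only reuses names and constructor arguments that appear in the diff and is not a definitive description of refiners internals.

# Hypothetical sanity check (not part of the patch): wrap the renamed layer in a
# plain chain and look at its first state-dict key. Assumes fl.Chain registers
# children under class-derived names, which is why the safetensors hash changed.
import refiners.fluxion.layers as fl
from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer

chain = fl.Chain(
    TwoWayTransformerLayer(embedding_dim=256, feed_forward_dim=2048, num_heads=8),
)
# Expect a key prefixed with the fixed name "TwoWayTransformerLayer", not the old typo.
print(next(iter(chain.state_dict().keys())))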