diff --git a/src/refiners/fluxion/layers/basics.py b/src/refiners/fluxion/layers/basics.py
index 872d68b..aa76694 100644
--- a/src/refiners/fluxion/layers/basics.py
+++ b/src/refiners/fluxion/layers/basics.py
@@ -84,7 +84,7 @@ class Permute(Module):
 
 
 class Slicing(Module):
-    def __init__(self, dim: int = 0, start: int = 0, end: int = -1, step: int = 1) -> None:
+    def __init__(self, dim: int = 0, start: int = 0, end: int | None = None, step: int = 1) -> None:
         super().__init__()
         self.dim = dim
         self.start = start
@@ -94,7 +94,9 @@ class Slicing(Module):
     def forward(self, x: Tensor) -> Tensor:
         dim_size = x.shape[self.dim]
         start = self.start if self.start >= 0 else dim_size + self.start
-        end = self.end if self.end >= 0 else dim_size + self.end
+        # None means "up to the end"; test it explicitly so an explicit end=0 is not mistaken for it.
+        end = dim_size if self.end is None else self.end
+        end = end if end >= 0 else dim_size + end
         start = max(min(start, dim_size), 0)
         end = max(min(end, dim_size), 0)
         if start >= end:
diff --git a/src/refiners/foundationals/latent_diffusion/image_prompt.py b/src/refiners/foundationals/latent_diffusion/image_prompt.py
index f34c005..9c4bac0 100644
--- a/src/refiners/foundationals/latent_diffusion/image_prompt.py
+++ b/src/refiners/foundationals/latent_diffusion/image_prompt.py
@@ -268,9 +268,7 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
                             InjectionPoint(),  # Wk
                         ),
                         fl.Chain(
-                            fl.Slicing(
-                                dim=1, start=text_sequence_length, end=text_sequence_length + image_sequence_length
-                            ),
+                            fl.Slicing(dim=1, start=text_sequence_length),
                             fl.Linear(
                                 in_features=self.target.key_embedding_dim,
                                 out_features=self.target.inner_dim,
@@ -286,9 +284,7 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
                             InjectionPoint(),  # Wv
                         ),
                         fl.Chain(
-                            fl.Slicing(
-                                dim=1, start=text_sequence_length, end=text_sequence_length + image_sequence_length
-                            ),
+                            fl.Slicing(dim=1, start=text_sequence_length),
                             fl.Linear(
                                 in_features=self.target.key_embedding_dim,
                                 out_features=self.target.inner_dim,
diff --git a/src/refiners/foundationals/segment_anything/mask_decoder.py b/src/refiners/foundationals/segment_anything/mask_decoder.py
index cdfb013..cd19668 100644
--- a/src/refiners/foundationals/segment_anything/mask_decoder.py
+++ b/src/refiners/foundationals/segment_anything/mask_decoder.py
@@ -156,7 +156,7 @@ class MaskPrediction(fl.Chain):
                 ),
                 other=DenseEmbeddingUpscaling(embedding_dim=embedding_dim, device=device, dtype=dtype),
             ),
-            fl.Slicing(dim=1, start=1, end=num_mask_tokens + 1),
+            fl.Slicing(dim=1, start=1),
             fl.Reshape(num_mask_tokens, embedding_dim, embedding_dim),
         )
 
@@ -183,7 +183,7 @@ class IOUPrediction(fl.Chain):
                 device=device,
                 dtype=dtype,
             ),
-            fl.Slicing(dim=-1, start=1, end=num_mask_tokens + 1),
+            fl.Slicing(dim=-1, start=1),
         )
 
 
diff --git a/tests/fluxion/layers/test_basics.py b/tests/fluxion/layers/test_basics.py
index 567ffb0..7fb6fb7 100644
--- a/tests/fluxion/layers/test_basics.py
+++ b/tests/fluxion/layers/test_basics.py
@@ -19,6 +19,14 @@ def test_slicing_negative_indices() -> None:
     assert torch.equal(sliced, expected)
 
 
+def test_none_end_slicing() -> None:
+    x = torch.randn(2, 1000, 400)
+    slicing = Slicing(dim=1, start=1)
+    sliced = slicing(x)
+    expected = x[:, 1:, :]
+    assert torch.equal(sliced, expected)
+
+
 def test_slicing_step() -> None:
     x = torch.randn(5, 5, 5)
     slicing_layer = Slicing(dim=1, start=0, end=5, step=2)
@@ -27,7 +35,7 @@ def test_slicing_step() -> None:
     assert torch.equal(sliced, expected)
 
 
-def test_slicing_empty_slice():
+def test_slicing_empty_slice() -> None:
     x = torch.randn(5, 5, 5)
     slicing_layer = Slicing(dim=1, start=3, end=3)
     sliced = slicing_layer(x)
@@ -35,7 +43,7 @@ def test_slicing_empty_slice():
     assert torch.equal(sliced, expected)
 
 
-def test_slicing_full_dimension():
+def test_slicing_full_dimension() -> None:
     x = torch.randn(5, 5, 5)
     slicing_layer = Slicing(dim=2, start=0, end=5)
     sliced = slicing_layer(x)
@@ -43,7 +51,7 @@ def test_slicing_full_dimension():
     assert torch.equal(sliced, expected)
 
 
-def test_slicing_step_greater_than_range():
+def test_slicing_step_greater_than_range() -> None:
     x = torch.randn(5, 5, 5)
     slicing_layer = Slicing(dim=1, start=1, end=3, step=4)
     sliced = slicing_layer(x)
@@ -51,7 +59,7 @@ def test_slicing_step_greater_than_range():
     assert torch.equal(sliced, expected)
 
 
-def test_slicing_reversed_start_end():
+def test_slicing_reversed_start_end() -> None:
     x = torch.randn(5, 5, 5)
     slicing_layer = Slicing(dim=1, start=4, end=2)
     sliced = slicing_layer(x)
@@ -59,7 +67,7 @@ def test_slicing_reversed_start_end():
     assert torch.equal(sliced, expected)
 
 
-def test_slicing_out_of_bounds_indices():
+def test_slicing_out_of_bounds_indices() -> None:
     x = torch.randn(5, 5, 5)
     slicing_layer = Slicing(dim=1, start=-10, end=10)
     sliced = slicing_layer(x)