diff --git a/src/refiners/fluxion/layers/basics.py b/src/refiners/fluxion/layers/basics.py
index 5f15050..872d68b 100644
--- a/src/refiners/fluxion/layers/basics.py
+++ b/src/refiners/fluxion/layers/basics.py
@@ -84,14 +84,31 @@ class Permute(Module):
 
 
 class Slicing(Module):
-    def __init__(self, dim: int, start: int, length: int) -> None:
+    def __init__(self, dim: int = 0, start: int = 0, end: int = -1, step: int = 1) -> None:
         super().__init__()
         self.dim = dim
         self.start = start
-        self.length = length
+        self.end = end
+        self.step = step
 
     def forward(self, x: Tensor) -> Tensor:
-        return x.narrow(self.dim, self.start, self.length)
+        dim_size = x.shape[self.dim]
+        start = self.start if self.start >= 0 else dim_size + self.start
+        end = self.end if self.end >= 0 else dim_size + self.end
+        start = max(min(start, dim_size), 0)
+        end = max(min(end, dim_size), 0)
+        if start >= end:
+            return self.get_empty_slice(x)
+        indices = torch.arange(start=start, end=end, step=self.step, device=x.device)
+        return x.index_select(self.dim, indices)
+
+    def get_empty_slice(self, x: Tensor) -> Tensor:
+        """
+        Return an empty slice of the same shape as the input tensor to mimic PyTorch's slicing behavior.
+        """
+        shape = list(x.shape)
+        shape[self.dim] = 0
+        return torch.empty(*shape, device=x.device)
 
 
 class Squeeze(Module):
diff --git a/src/refiners/foundationals/latent_diffusion/auto_encoder.py b/src/refiners/foundationals/latent_diffusion/auto_encoder.py
index 2dc3bd5..9fee47d 100644
--- a/src/refiners/foundationals/latent_diffusion/auto_encoder.py
+++ b/src/refiners/foundationals/latent_diffusion/auto_encoder.py
@@ -113,7 +113,7 @@ class Encoder(Chain):
             ),
             Chain(
                 Conv2d(in_channels=8, out_channels=8, kernel_size=1, device=device, dtype=dtype),
-                Slicing(dim=1, start=0, length=4),
+                Slicing(dim=1, end=4),
             ),
         )
diff --git a/src/refiners/foundationals/latent_diffusion/image_prompt.py b/src/refiners/foundationals/latent_diffusion/image_prompt.py
index c25b82b..f34c005 100644
--- a/src/refiners/foundationals/latent_diffusion/image_prompt.py
+++ b/src/refiners/foundationals/latent_diffusion/image_prompt.py
@@ -264,11 +264,13 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
                 InjectionPoint(),  # Wq
                 fl.Parallel(
                     fl.Chain(
-                        fl.Slicing(dim=1, start=0, length=text_sequence_length),
+                        fl.Slicing(dim=1, end=text_sequence_length),
                         InjectionPoint(),  # Wk
                     ),
                     fl.Chain(
-                        fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length),
+                        fl.Slicing(
+                            dim=1, start=text_sequence_length, end=text_sequence_length + image_sequence_length
+                        ),
                         fl.Linear(
                             in_features=self.target.key_embedding_dim,
                             out_features=self.target.inner_dim,
@@ -280,11 +282,13 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
                 ),
                 fl.Parallel(
                     fl.Chain(
-                        fl.Slicing(dim=1, start=0, length=text_sequence_length),
+                        fl.Slicing(dim=1, end=text_sequence_length),
                         InjectionPoint(),  # Wv
                     ),
                     fl.Chain(
-                        fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length),
+                        fl.Slicing(
+                            dim=1, start=text_sequence_length, end=text_sequence_length + image_sequence_length
+                        ),
                         fl.Linear(
                             in_features=self.target.key_embedding_dim,
                             out_features=self.target.inner_dim,
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py
index 21e2f50..1b8b9ff 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py
@@ -86,7 +86,7 @@ class Controlnet(Passthrough):
         self.scale = scale
         super().__init__(
             TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
-            Slicing(dim=1, start=0, length=4),  # support inpainting
+            Slicing(dim=1, end=4),  # support inpainting
             DownBlocks(in_channels=4, device=device, dtype=dtype),
             MiddleBlock(device=device, dtype=dtype),
         )
diff --git a/src/refiners/foundationals/segment_anything/mask_decoder.py b/src/refiners/foundationals/segment_anything/mask_decoder.py
index b0ee47d..cdfb013 100644
--- a/src/refiners/foundationals/segment_anything/mask_decoder.py
+++ b/src/refiners/foundationals/segment_anything/mask_decoder.py
@@ -60,7 +60,7 @@ class Hypernetworks(fl.Concatenate):
         super().__init__(
             *[
                 fl.Chain(
-                    fl.Slicing(dim=1, start=i + 1, length=1),
+                    fl.Slicing(dim=1, start=i + 1, end=i + 2),
                     fl.MultiLinear(
                         input_dim=embedding_dim,
                         output_dim=embedding_dim // 8,
@@ -156,7 +156,7 @@ class MaskPrediction(fl.Chain):
                 ),
                 other=DenseEmbeddingUpscaling(embedding_dim=embedding_dim, device=device, dtype=dtype),
             ),
-            fl.Slicing(dim=1, start=1, length=num_mask_tokens),
+            fl.Slicing(dim=1, start=1, end=num_mask_tokens + 1),
             fl.Reshape(num_mask_tokens, embedding_dim, embedding_dim),
         )
@@ -173,7 +173,7 @@ class IOUPrediction(fl.Chain):
         self.embedding_dim = embedding_dim
         self.num_layers = num_layers
         super().__init__(
-            fl.Slicing(dim=1, start=0, length=1),
+            fl.Slicing(dim=1, start=0, end=1),
             fl.Squeeze(dim=0),
             fl.MultiLinear(
                 input_dim=embedding_dim,
@@ -183,7 +183,7 @@ class IOUPrediction(fl.Chain):
                 device=device,
                 dtype=dtype,
             ),
-            fl.Slicing(dim=-1, start=1, length=num_mask_tokens),
+            fl.Slicing(dim=-1, start=1, end=num_mask_tokens + 1),
         )
diff --git a/tests/fluxion/layers/test_basics.py b/tests/fluxion/layers/test_basics.py
new file mode 100644
index 0000000..567ffb0
--- /dev/null
+++ b/tests/fluxion/layers/test_basics.py
@@ -0,0 +1,67 @@
+import torch
+
+from refiners.fluxion.layers.basics import Slicing
+
+
+def test_slicing_positive_indices() -> None:
+    x = torch.randn(5, 5, 5)
+    slicing_layer = Slicing(dim=0, start=2, end=5)
+    sliced = slicing_layer(x)
+    expected = x[2:5, :]
+    assert torch.equal(sliced, expected)
+
+
+def test_slicing_negative_indices() -> None:
+    x = torch.randn(5, 5, 5)
+    slicing_layer = Slicing(dim=1, start=-3, end=-1)
+    sliced = slicing_layer(x)
+    expected = x[:, -3:-1]
+    assert torch.equal(sliced, expected)
+
+
+def test_slicing_step() -> None:
+    x = torch.randn(5, 5, 5)
+    slicing_layer = Slicing(dim=1, start=0, end=5, step=2)
+    sliced = slicing_layer(x)
+    expected = x[:, 0:5:2]
+    assert torch.equal(sliced, expected)
+
+
+def test_slicing_empty_slice():
+    x = torch.randn(5, 5, 5)
+    slicing_layer = Slicing(dim=1, start=3, end=3)
+    sliced = slicing_layer(x)
+    expected = x[:, 3:3]
+    assert torch.equal(sliced, expected)
+
+
+def test_slicing_full_dimension():
+    x = torch.randn(5, 5, 5)
+    slicing_layer = Slicing(dim=2, start=0, end=5)
+    sliced = slicing_layer(x)
+    expected = x[:, :, :]
+    assert torch.equal(sliced, expected)
+
+
+def test_slicing_step_greater_than_range():
+    x = torch.randn(5, 5, 5)
+    slicing_layer = Slicing(dim=1, start=1, end=3, step=4)
+    sliced = slicing_layer(x)
+    expected = x[:, 1:3:4]
+    assert torch.equal(sliced, expected)
+
+
+def test_slicing_reversed_start_end():
+    x = torch.randn(5, 5, 5)
+    slicing_layer = Slicing(dim=1, start=4, end=2)
+    sliced = slicing_layer(x)
+    expected = x[:, 4:2]
+    assert torch.equal(sliced, expected)
+
+
+def test_slicing_out_of_bounds_indices():
+    x = torch.randn(5, 5, 5)
+    slicing_layer = Slicing(dim=1, start=-10, end=10)
+    sliced = slicing_layer(x)
+    expected = x[:, -10:10]
+    assert torch.equal(sliced, expected)
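For review context only, not part of the patch: a minimal usage sketch of the new Slicing signature, assuming the patched refiners.fluxion.layers.basics is importable. The tensor shape, variable names, and assertions below are illustrative.

# Illustrative sketch, not library code: exercises the new start/end/step
# semantics of Slicing, which mirror Python/PyTorch slicing (end is exclusive).
import torch

from refiners.fluxion.layers.basics import Slicing

x = torch.randn(2, 10, 4)

# Replaces the old start/length pair, e.g. Slicing(dim=1, start=0, length=4).
first_four = Slicing(dim=1, end=4)(x)
assert torch.equal(first_four, x[:, :4])

# start/end instead of start/length.
middle = Slicing(dim=1, start=4, end=8)(x)
assert torch.equal(middle, x[:, 4:8])

# Negative indices and a step are now supported.
every_other = Slicing(dim=1, start=-6, end=-1, step=2)(x)
assert torch.equal(every_other, x[:, -6:-1:2])

# Out-of-range bounds are clamped, matching x[:, a:b] behavior.
clamped = Slicing(dim=1, start=-100, end=100)(x)
assert torch.equal(clamped, x[:, -100:100])

# Note: the default end=-1 resolves to dim_size - 1, so a bare Slicing(dim=1)
# keeps everything except the last element along dim, i.e. x[:, :-1].
assert torch.equal(Slicing(dim=1)(x), x[:, :-1])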