Change fl.Slicing API

This commit is contained in:
limiteinductive 2023-12-12 20:09:58 +01:00 committed by Benjamin Trom
parent 11b0ff6f8c
commit a7551e0392
6 changed files with 101 additions and 13 deletions

View file

@ -84,14 +84,31 @@ class Permute(Module):
class Slicing(Module): class Slicing(Module):
def __init__(self, dim: int, start: int, length: int) -> None: def __init__(self, dim: int = 0, start: int = 0, end: int = -1, step: int = 1) -> None:
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.start = start self.start = start
self.length = length self.end = end
self.step = step
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
return x.narrow(self.dim, self.start, self.length) dim_size = x.shape[self.dim]
start = self.start if self.start >= 0 else dim_size + self.start
end = self.end if self.end >= 0 else dim_size + self.end
start = max(min(start, dim_size), 0)
end = max(min(end, dim_size), 0)
if start >= end:
return self.get_empty_slice(x)
indices = torch.arange(start=start, end=end, step=self.step, device=x.device)
return x.index_select(self.dim, indices)
def get_empty_slice(self, x: Tensor) -> Tensor:
"""
Return an empty slice of the same shape as the input tensor to mimic PyTorch's slicing behavior.
"""
shape = list(x.shape)
shape[self.dim] = 0
return torch.empty(*shape, device=x.device)
class Squeeze(Module): class Squeeze(Module):

View file

@ -113,7 +113,7 @@ class Encoder(Chain):
), ),
Chain( Chain(
Conv2d(in_channels=8, out_channels=8, kernel_size=1, device=device, dtype=dtype), Conv2d(in_channels=8, out_channels=8, kernel_size=1, device=device, dtype=dtype),
Slicing(dim=1, start=0, length=4), Slicing(dim=1, end=4),
), ),
) )

View file

@ -264,11 +264,13 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
InjectionPoint(), # Wq InjectionPoint(), # Wq
fl.Parallel( fl.Parallel(
fl.Chain( fl.Chain(
fl.Slicing(dim=1, start=0, length=text_sequence_length), fl.Slicing(dim=1, end=text_sequence_length),
InjectionPoint(), # Wk InjectionPoint(), # Wk
), ),
fl.Chain( fl.Chain(
fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length), fl.Slicing(
dim=1, start=text_sequence_length, end=text_sequence_length + image_sequence_length
),
fl.Linear( fl.Linear(
in_features=self.target.key_embedding_dim, in_features=self.target.key_embedding_dim,
out_features=self.target.inner_dim, out_features=self.target.inner_dim,
@ -280,11 +282,13 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
), ),
fl.Parallel( fl.Parallel(
fl.Chain( fl.Chain(
fl.Slicing(dim=1, start=0, length=text_sequence_length), fl.Slicing(dim=1, end=text_sequence_length),
InjectionPoint(), # Wv InjectionPoint(), # Wv
), ),
fl.Chain( fl.Chain(
fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length), fl.Slicing(
dim=1, start=text_sequence_length, end=text_sequence_length + image_sequence_length
),
fl.Linear( fl.Linear(
in_features=self.target.key_embedding_dim, in_features=self.target.key_embedding_dim,
out_features=self.target.inner_dim, out_features=self.target.inner_dim,

View file

@ -86,7 +86,7 @@ class Controlnet(Passthrough):
self.scale = scale self.scale = scale
super().__init__( super().__init__(
TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype), TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
Slicing(dim=1, start=0, length=4), # support inpainting Slicing(dim=1, end=4), # support inpainting
DownBlocks(in_channels=4, device=device, dtype=dtype), DownBlocks(in_channels=4, device=device, dtype=dtype),
MiddleBlock(device=device, dtype=dtype), MiddleBlock(device=device, dtype=dtype),
) )

View file

@ -60,7 +60,7 @@ class Hypernetworks(fl.Concatenate):
super().__init__( super().__init__(
*[ *[
fl.Chain( fl.Chain(
fl.Slicing(dim=1, start=i + 1, length=1), fl.Slicing(dim=1, start=i + 1, end=i + 2),
fl.MultiLinear( fl.MultiLinear(
input_dim=embedding_dim, input_dim=embedding_dim,
output_dim=embedding_dim // 8, output_dim=embedding_dim // 8,
@ -156,7 +156,7 @@ class MaskPrediction(fl.Chain):
), ),
other=DenseEmbeddingUpscaling(embedding_dim=embedding_dim, device=device, dtype=dtype), other=DenseEmbeddingUpscaling(embedding_dim=embedding_dim, device=device, dtype=dtype),
), ),
fl.Slicing(dim=1, start=1, length=num_mask_tokens), fl.Slicing(dim=1, start=1, end=num_mask_tokens + 1),
fl.Reshape(num_mask_tokens, embedding_dim, embedding_dim), fl.Reshape(num_mask_tokens, embedding_dim, embedding_dim),
) )
@ -173,7 +173,7 @@ class IOUPrediction(fl.Chain):
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
self.num_layers = num_layers self.num_layers = num_layers
super().__init__( super().__init__(
fl.Slicing(dim=1, start=0, length=1), fl.Slicing(dim=1, start=0, end=1),
fl.Squeeze(dim=0), fl.Squeeze(dim=0),
fl.MultiLinear( fl.MultiLinear(
input_dim=embedding_dim, input_dim=embedding_dim,
@ -183,7 +183,7 @@ class IOUPrediction(fl.Chain):
device=device, device=device,
dtype=dtype, dtype=dtype,
), ),
fl.Slicing(dim=-1, start=1, length=num_mask_tokens), fl.Slicing(dim=-1, start=1, end=num_mask_tokens + 1),
) )

View file

@ -0,0 +1,67 @@
import torch
from refiners.fluxion.layers.basics import Slicing
def test_slicing_positive_indices() -> None:
x = torch.randn(5, 5, 5)
slicing_layer = Slicing(dim=0, start=2, end=5)
sliced = slicing_layer(x)
expected = x[2:5, :]
assert torch.equal(sliced, expected)
def test_slicing_negative_indices() -> None:
x = torch.randn(5, 5, 5)
slicing_layer = Slicing(dim=1, start=-3, end=-1)
sliced = slicing_layer(x)
expected = x[:, -3:-1]
assert torch.equal(sliced, expected)
def test_slicing_step() -> None:
x = torch.randn(5, 5, 5)
slicing_layer = Slicing(dim=1, start=0, end=5, step=2)
sliced = slicing_layer(x)
expected = x[:, 0:5:2]
assert torch.equal(sliced, expected)
def test_slicing_empty_slice():
x = torch.randn(5, 5, 5)
slicing_layer = Slicing(dim=1, start=3, end=3)
sliced = slicing_layer(x)
expected = x[:, 3:3]
assert torch.equal(sliced, expected)
def test_slicing_full_dimension():
x = torch.randn(5, 5, 5)
slicing_layer = Slicing(dim=2, start=0, end=5)
sliced = slicing_layer(x)
expected = x[:, :, :]
assert torch.equal(sliced, expected)
def test_slicing_step_greater_than_range():
x = torch.randn(5, 5, 5)
slicing_layer = Slicing(dim=1, start=1, end=3, step=4)
sliced = slicing_layer(x)
expected = x[:, 1:3:4]
assert torch.equal(sliced, expected)
def test_slicing_reversed_start_end():
x = torch.randn(5, 5, 5)
slicing_layer = Slicing(dim=1, start=4, end=2)
sliced = slicing_layer(x)
expected = x[:, 4:2]
assert torch.equal(sliced, expected)
def test_slicing_out_of_bounds_indices():
x = torch.randn(5, 5, 5)
slicing_layer = Slicing(dim=1, start=-10, end=10)
sliced = slicing_layer(x)
expected = x[:, -10:10]
assert torch.equal(sliced, expected)