mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 15:02:01 +00:00
Change fl.Slicing API
This commit is contained in:
parent
11b0ff6f8c
commit
a7551e0392
|
@ -84,14 +84,31 @@ class Permute(Module):
|
|||
|
||||
|
||||
class Slicing(Module):
|
||||
def __init__(self, dim: int, start: int, length: int) -> None:
|
||||
def __init__(self, dim: int = 0, start: int = 0, end: int = -1, step: int = 1) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.start = start
|
||||
self.length = length
|
||||
self.end = end
|
||||
self.step = step
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return x.narrow(self.dim, self.start, self.length)
|
||||
dim_size = x.shape[self.dim]
|
||||
start = self.start if self.start >= 0 else dim_size + self.start
|
||||
end = self.end if self.end >= 0 else dim_size + self.end
|
||||
start = max(min(start, dim_size), 0)
|
||||
end = max(min(end, dim_size), 0)
|
||||
if start >= end:
|
||||
return self.get_empty_slice(x)
|
||||
indices = torch.arange(start=start, end=end, step=self.step, device=x.device)
|
||||
return x.index_select(self.dim, indices)
|
||||
|
||||
def get_empty_slice(self, x: Tensor) -> Tensor:
|
||||
"""
|
||||
Return an empty slice of the same shape as the input tensor to mimic PyTorch's slicing behavior.
|
||||
"""
|
||||
shape = list(x.shape)
|
||||
shape[self.dim] = 0
|
||||
return torch.empty(*shape, device=x.device)
|
||||
|
||||
|
||||
class Squeeze(Module):
|
||||
|
|
|
@ -113,7 +113,7 @@ class Encoder(Chain):
|
|||
),
|
||||
Chain(
|
||||
Conv2d(in_channels=8, out_channels=8, kernel_size=1, device=device, dtype=dtype),
|
||||
Slicing(dim=1, start=0, length=4),
|
||||
Slicing(dim=1, end=4),
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
@ -264,11 +264,13 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
|
|||
InjectionPoint(), # Wq
|
||||
fl.Parallel(
|
||||
fl.Chain(
|
||||
fl.Slicing(dim=1, start=0, length=text_sequence_length),
|
||||
fl.Slicing(dim=1, end=text_sequence_length),
|
||||
InjectionPoint(), # Wk
|
||||
),
|
||||
fl.Chain(
|
||||
fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length),
|
||||
fl.Slicing(
|
||||
dim=1, start=text_sequence_length, end=text_sequence_length + image_sequence_length
|
||||
),
|
||||
fl.Linear(
|
||||
in_features=self.target.key_embedding_dim,
|
||||
out_features=self.target.inner_dim,
|
||||
|
@ -280,11 +282,13 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
|
|||
),
|
||||
fl.Parallel(
|
||||
fl.Chain(
|
||||
fl.Slicing(dim=1, start=0, length=text_sequence_length),
|
||||
fl.Slicing(dim=1, end=text_sequence_length),
|
||||
InjectionPoint(), # Wv
|
||||
),
|
||||
fl.Chain(
|
||||
fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length),
|
||||
fl.Slicing(
|
||||
dim=1, start=text_sequence_length, end=text_sequence_length + image_sequence_length
|
||||
),
|
||||
fl.Linear(
|
||||
in_features=self.target.key_embedding_dim,
|
||||
out_features=self.target.inner_dim,
|
||||
|
|
|
@ -86,7 +86,7 @@ class Controlnet(Passthrough):
|
|||
self.scale = scale
|
||||
super().__init__(
|
||||
TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
|
||||
Slicing(dim=1, start=0, length=4), # support inpainting
|
||||
Slicing(dim=1, end=4), # support inpainting
|
||||
DownBlocks(in_channels=4, device=device, dtype=dtype),
|
||||
MiddleBlock(device=device, dtype=dtype),
|
||||
)
|
||||
|
|
|
@ -60,7 +60,7 @@ class Hypernetworks(fl.Concatenate):
|
|||
super().__init__(
|
||||
*[
|
||||
fl.Chain(
|
||||
fl.Slicing(dim=1, start=i + 1, length=1),
|
||||
fl.Slicing(dim=1, start=i + 1, end=i + 2),
|
||||
fl.MultiLinear(
|
||||
input_dim=embedding_dim,
|
||||
output_dim=embedding_dim // 8,
|
||||
|
@ -156,7 +156,7 @@ class MaskPrediction(fl.Chain):
|
|||
),
|
||||
other=DenseEmbeddingUpscaling(embedding_dim=embedding_dim, device=device, dtype=dtype),
|
||||
),
|
||||
fl.Slicing(dim=1, start=1, length=num_mask_tokens),
|
||||
fl.Slicing(dim=1, start=1, end=num_mask_tokens + 1),
|
||||
fl.Reshape(num_mask_tokens, embedding_dim, embedding_dim),
|
||||
)
|
||||
|
||||
|
@ -173,7 +173,7 @@ class IOUPrediction(fl.Chain):
|
|||
self.embedding_dim = embedding_dim
|
||||
self.num_layers = num_layers
|
||||
super().__init__(
|
||||
fl.Slicing(dim=1, start=0, length=1),
|
||||
fl.Slicing(dim=1, start=0, end=1),
|
||||
fl.Squeeze(dim=0),
|
||||
fl.MultiLinear(
|
||||
input_dim=embedding_dim,
|
||||
|
@ -183,7 +183,7 @@ class IOUPrediction(fl.Chain):
|
|||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
fl.Slicing(dim=-1, start=1, length=num_mask_tokens),
|
||||
fl.Slicing(dim=-1, start=1, end=num_mask_tokens + 1),
|
||||
)
|
||||
|
||||
|
||||
|
|
67
tests/fluxion/layers/test_basics.py
Normal file
67
tests/fluxion/layers/test_basics.py
Normal file
|
@ -0,0 +1,67 @@
|
|||
import torch
|
||||
|
||||
from refiners.fluxion.layers.basics import Slicing
|
||||
|
||||
|
||||
def test_slicing_positive_indices() -> None:
|
||||
x = torch.randn(5, 5, 5)
|
||||
slicing_layer = Slicing(dim=0, start=2, end=5)
|
||||
sliced = slicing_layer(x)
|
||||
expected = x[2:5, :]
|
||||
assert torch.equal(sliced, expected)
|
||||
|
||||
|
||||
def test_slicing_negative_indices() -> None:
|
||||
x = torch.randn(5, 5, 5)
|
||||
slicing_layer = Slicing(dim=1, start=-3, end=-1)
|
||||
sliced = slicing_layer(x)
|
||||
expected = x[:, -3:-1]
|
||||
assert torch.equal(sliced, expected)
|
||||
|
||||
|
||||
def test_slicing_step() -> None:
|
||||
x = torch.randn(5, 5, 5)
|
||||
slicing_layer = Slicing(dim=1, start=0, end=5, step=2)
|
||||
sliced = slicing_layer(x)
|
||||
expected = x[:, 0:5:2]
|
||||
assert torch.equal(sliced, expected)
|
||||
|
||||
|
||||
def test_slicing_empty_slice():
|
||||
x = torch.randn(5, 5, 5)
|
||||
slicing_layer = Slicing(dim=1, start=3, end=3)
|
||||
sliced = slicing_layer(x)
|
||||
expected = x[:, 3:3]
|
||||
assert torch.equal(sliced, expected)
|
||||
|
||||
|
||||
def test_slicing_full_dimension():
|
||||
x = torch.randn(5, 5, 5)
|
||||
slicing_layer = Slicing(dim=2, start=0, end=5)
|
||||
sliced = slicing_layer(x)
|
||||
expected = x[:, :, :]
|
||||
assert torch.equal(sliced, expected)
|
||||
|
||||
|
||||
def test_slicing_step_greater_than_range():
|
||||
x = torch.randn(5, 5, 5)
|
||||
slicing_layer = Slicing(dim=1, start=1, end=3, step=4)
|
||||
sliced = slicing_layer(x)
|
||||
expected = x[:, 1:3:4]
|
||||
assert torch.equal(sliced, expected)
|
||||
|
||||
|
||||
def test_slicing_reversed_start_end():
|
||||
x = torch.randn(5, 5, 5)
|
||||
slicing_layer = Slicing(dim=1, start=4, end=2)
|
||||
sliced = slicing_layer(x)
|
||||
expected = x[:, 4:2]
|
||||
assert torch.equal(sliced, expected)
|
||||
|
||||
|
||||
def test_slicing_out_of_bounds_indices():
|
||||
x = torch.randn(5, 5, 5)
|
||||
slicing_layer = Slicing(dim=1, start=-10, end=10)
|
||||
sliced = slicing_layer(x)
|
||||
expected = x[:, -10:10]
|
||||
assert torch.equal(sliced, expected)
|
Loading…
Reference in a new issue