mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
Change fl.Slicing API
This commit is contained in:
parent
11b0ff6f8c
commit
a7551e0392
|
@ -84,14 +84,31 @@ class Permute(Module):
|
||||||
|
|
||||||
|
|
||||||
class Slicing(Module):
|
class Slicing(Module):
|
||||||
def __init__(self, dim: int, start: int, length: int) -> None:
|
def __init__(self, dim: int = 0, start: int = 0, end: int = -1, step: int = 1) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.start = start
|
self.start = start
|
||||||
self.length = length
|
self.end = end
|
||||||
|
self.step = step
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return x.narrow(self.dim, self.start, self.length)
|
dim_size = x.shape[self.dim]
|
||||||
|
start = self.start if self.start >= 0 else dim_size + self.start
|
||||||
|
end = self.end if self.end >= 0 else dim_size + self.end
|
||||||
|
start = max(min(start, dim_size), 0)
|
||||||
|
end = max(min(end, dim_size), 0)
|
||||||
|
if start >= end:
|
||||||
|
return self.get_empty_slice(x)
|
||||||
|
indices = torch.arange(start=start, end=end, step=self.step, device=x.device)
|
||||||
|
return x.index_select(self.dim, indices)
|
||||||
|
|
||||||
|
def get_empty_slice(self, x: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
Return an empty slice of the same shape as the input tensor to mimic PyTorch's slicing behavior.
|
||||||
|
"""
|
||||||
|
shape = list(x.shape)
|
||||||
|
shape[self.dim] = 0
|
||||||
|
return torch.empty(*shape, device=x.device)
|
||||||
|
|
||||||
|
|
||||||
class Squeeze(Module):
|
class Squeeze(Module):
|
||||||
|
|
|
@ -113,7 +113,7 @@ class Encoder(Chain):
|
||||||
),
|
),
|
||||||
Chain(
|
Chain(
|
||||||
Conv2d(in_channels=8, out_channels=8, kernel_size=1, device=device, dtype=dtype),
|
Conv2d(in_channels=8, out_channels=8, kernel_size=1, device=device, dtype=dtype),
|
||||||
Slicing(dim=1, start=0, length=4),
|
Slicing(dim=1, end=4),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -264,11 +264,13 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
|
||||||
InjectionPoint(), # Wq
|
InjectionPoint(), # Wq
|
||||||
fl.Parallel(
|
fl.Parallel(
|
||||||
fl.Chain(
|
fl.Chain(
|
||||||
fl.Slicing(dim=1, start=0, length=text_sequence_length),
|
fl.Slicing(dim=1, end=text_sequence_length),
|
||||||
InjectionPoint(), # Wk
|
InjectionPoint(), # Wk
|
||||||
),
|
),
|
||||||
fl.Chain(
|
fl.Chain(
|
||||||
fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length),
|
fl.Slicing(
|
||||||
|
dim=1, start=text_sequence_length, end=text_sequence_length + image_sequence_length
|
||||||
|
),
|
||||||
fl.Linear(
|
fl.Linear(
|
||||||
in_features=self.target.key_embedding_dim,
|
in_features=self.target.key_embedding_dim,
|
||||||
out_features=self.target.inner_dim,
|
out_features=self.target.inner_dim,
|
||||||
|
@ -280,11 +282,13 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
|
||||||
),
|
),
|
||||||
fl.Parallel(
|
fl.Parallel(
|
||||||
fl.Chain(
|
fl.Chain(
|
||||||
fl.Slicing(dim=1, start=0, length=text_sequence_length),
|
fl.Slicing(dim=1, end=text_sequence_length),
|
||||||
InjectionPoint(), # Wv
|
InjectionPoint(), # Wv
|
||||||
),
|
),
|
||||||
fl.Chain(
|
fl.Chain(
|
||||||
fl.Slicing(dim=1, start=text_sequence_length, length=image_sequence_length),
|
fl.Slicing(
|
||||||
|
dim=1, start=text_sequence_length, end=text_sequence_length + image_sequence_length
|
||||||
|
),
|
||||||
fl.Linear(
|
fl.Linear(
|
||||||
in_features=self.target.key_embedding_dim,
|
in_features=self.target.key_embedding_dim,
|
||||||
out_features=self.target.inner_dim,
|
out_features=self.target.inner_dim,
|
||||||
|
|
|
@ -86,7 +86,7 @@ class Controlnet(Passthrough):
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
super().__init__(
|
super().__init__(
|
||||||
TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
|
TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
|
||||||
Slicing(dim=1, start=0, length=4), # support inpainting
|
Slicing(dim=1, end=4), # support inpainting
|
||||||
DownBlocks(in_channels=4, device=device, dtype=dtype),
|
DownBlocks(in_channels=4, device=device, dtype=dtype),
|
||||||
MiddleBlock(device=device, dtype=dtype),
|
MiddleBlock(device=device, dtype=dtype),
|
||||||
)
|
)
|
||||||
|
|
|
@ -60,7 +60,7 @@ class Hypernetworks(fl.Concatenate):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
*[
|
*[
|
||||||
fl.Chain(
|
fl.Chain(
|
||||||
fl.Slicing(dim=1, start=i + 1, length=1),
|
fl.Slicing(dim=1, start=i + 1, end=i + 2),
|
||||||
fl.MultiLinear(
|
fl.MultiLinear(
|
||||||
input_dim=embedding_dim,
|
input_dim=embedding_dim,
|
||||||
output_dim=embedding_dim // 8,
|
output_dim=embedding_dim // 8,
|
||||||
|
@ -156,7 +156,7 @@ class MaskPrediction(fl.Chain):
|
||||||
),
|
),
|
||||||
other=DenseEmbeddingUpscaling(embedding_dim=embedding_dim, device=device, dtype=dtype),
|
other=DenseEmbeddingUpscaling(embedding_dim=embedding_dim, device=device, dtype=dtype),
|
||||||
),
|
),
|
||||||
fl.Slicing(dim=1, start=1, length=num_mask_tokens),
|
fl.Slicing(dim=1, start=1, end=num_mask_tokens + 1),
|
||||||
fl.Reshape(num_mask_tokens, embedding_dim, embedding_dim),
|
fl.Reshape(num_mask_tokens, embedding_dim, embedding_dim),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -173,7 +173,7 @@ class IOUPrediction(fl.Chain):
|
||||||
self.embedding_dim = embedding_dim
|
self.embedding_dim = embedding_dim
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Slicing(dim=1, start=0, length=1),
|
fl.Slicing(dim=1, start=0, end=1),
|
||||||
fl.Squeeze(dim=0),
|
fl.Squeeze(dim=0),
|
||||||
fl.MultiLinear(
|
fl.MultiLinear(
|
||||||
input_dim=embedding_dim,
|
input_dim=embedding_dim,
|
||||||
|
@ -183,7 +183,7 @@ class IOUPrediction(fl.Chain):
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
),
|
),
|
||||||
fl.Slicing(dim=-1, start=1, length=num_mask_tokens),
|
fl.Slicing(dim=-1, start=1, end=num_mask_tokens + 1),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
67
tests/fluxion/layers/test_basics.py
Normal file
67
tests/fluxion/layers/test_basics.py
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from refiners.fluxion.layers.basics import Slicing
|
||||||
|
|
||||||
|
|
||||||
|
def test_slicing_positive_indices() -> None:
|
||||||
|
x = torch.randn(5, 5, 5)
|
||||||
|
slicing_layer = Slicing(dim=0, start=2, end=5)
|
||||||
|
sliced = slicing_layer(x)
|
||||||
|
expected = x[2:5, :]
|
||||||
|
assert torch.equal(sliced, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_slicing_negative_indices() -> None:
|
||||||
|
x = torch.randn(5, 5, 5)
|
||||||
|
slicing_layer = Slicing(dim=1, start=-3, end=-1)
|
||||||
|
sliced = slicing_layer(x)
|
||||||
|
expected = x[:, -3:-1]
|
||||||
|
assert torch.equal(sliced, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_slicing_step() -> None:
|
||||||
|
x = torch.randn(5, 5, 5)
|
||||||
|
slicing_layer = Slicing(dim=1, start=0, end=5, step=2)
|
||||||
|
sliced = slicing_layer(x)
|
||||||
|
expected = x[:, 0:5:2]
|
||||||
|
assert torch.equal(sliced, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_slicing_empty_slice():
|
||||||
|
x = torch.randn(5, 5, 5)
|
||||||
|
slicing_layer = Slicing(dim=1, start=3, end=3)
|
||||||
|
sliced = slicing_layer(x)
|
||||||
|
expected = x[:, 3:3]
|
||||||
|
assert torch.equal(sliced, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_slicing_full_dimension():
|
||||||
|
x = torch.randn(5, 5, 5)
|
||||||
|
slicing_layer = Slicing(dim=2, start=0, end=5)
|
||||||
|
sliced = slicing_layer(x)
|
||||||
|
expected = x[:, :, :]
|
||||||
|
assert torch.equal(sliced, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_slicing_step_greater_than_range():
|
||||||
|
x = torch.randn(5, 5, 5)
|
||||||
|
slicing_layer = Slicing(dim=1, start=1, end=3, step=4)
|
||||||
|
sliced = slicing_layer(x)
|
||||||
|
expected = x[:, 1:3:4]
|
||||||
|
assert torch.equal(sliced, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_slicing_reversed_start_end():
|
||||||
|
x = torch.randn(5, 5, 5)
|
||||||
|
slicing_layer = Slicing(dim=1, start=4, end=2)
|
||||||
|
sliced = slicing_layer(x)
|
||||||
|
expected = x[:, 4:2]
|
||||||
|
assert torch.equal(sliced, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_slicing_out_of_bounds_indices():
|
||||||
|
x = torch.randn(5, 5, 5)
|
||||||
|
slicing_layer = Slicing(dim=1, start=-10, end=10)
|
||||||
|
sliced = slicing_layer(x)
|
||||||
|
expected = x[:, -10:10]
|
||||||
|
assert torch.equal(sliced, expected)
|
Loading…
Reference in a new issue