mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
change default behavior of end to None
This commit is contained in:
parent
82a2aa1ec4
commit
7d9ceae274
|
@ -84,7 +84,7 @@ class Permute(Module):
|
||||||
|
|
||||||
|
|
||||||
class Slicing(Module):
|
class Slicing(Module):
|
||||||
def __init__(self, dim: int = 0, start: int = 0, end: int = -1, step: int = 1) -> None:
|
def __init__(self, dim: int = 0, start: int = 0, end: int | None = None, step: int = 1) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.start = start
|
self.start = start
|
||||||
|
@ -94,7 +94,8 @@ class Slicing(Module):
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
dim_size = x.shape[self.dim]
|
dim_size = x.shape[self.dim]
|
||||||
start = self.start if self.start >= 0 else dim_size + self.start
|
start = self.start if self.start >= 0 else dim_size + self.start
|
||||||
end = self.end if self.end >= 0 else dim_size + self.end
|
end = self.end or dim_size
|
||||||
|
end = end if end >= 0 else dim_size + end
|
||||||
start = max(min(start, dim_size), 0)
|
start = max(min(start, dim_size), 0)
|
||||||
end = max(min(end, dim_size), 0)
|
end = max(min(end, dim_size), 0)
|
||||||
if start >= end:
|
if start >= end:
|
||||||
|
|
|
@ -268,9 +268,7 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
|
||||||
InjectionPoint(), # Wk
|
InjectionPoint(), # Wk
|
||||||
),
|
),
|
||||||
fl.Chain(
|
fl.Chain(
|
||||||
fl.Slicing(
|
fl.Slicing(dim=1, start=text_sequence_length),
|
||||||
dim=1, start=text_sequence_length, end=text_sequence_length + image_sequence_length
|
|
||||||
),
|
|
||||||
fl.Linear(
|
fl.Linear(
|
||||||
in_features=self.target.key_embedding_dim,
|
in_features=self.target.key_embedding_dim,
|
||||||
out_features=self.target.inner_dim,
|
out_features=self.target.inner_dim,
|
||||||
|
@ -286,9 +284,7 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
|
||||||
InjectionPoint(), # Wv
|
InjectionPoint(), # Wv
|
||||||
),
|
),
|
||||||
fl.Chain(
|
fl.Chain(
|
||||||
fl.Slicing(
|
fl.Slicing(dim=1, start=text_sequence_length),
|
||||||
dim=1, start=text_sequence_length, end=text_sequence_length + image_sequence_length
|
|
||||||
),
|
|
||||||
fl.Linear(
|
fl.Linear(
|
||||||
in_features=self.target.key_embedding_dim,
|
in_features=self.target.key_embedding_dim,
|
||||||
out_features=self.target.inner_dim,
|
out_features=self.target.inner_dim,
|
||||||
|
|
|
@ -156,7 +156,7 @@ class MaskPrediction(fl.Chain):
|
||||||
),
|
),
|
||||||
other=DenseEmbeddingUpscaling(embedding_dim=embedding_dim, device=device, dtype=dtype),
|
other=DenseEmbeddingUpscaling(embedding_dim=embedding_dim, device=device, dtype=dtype),
|
||||||
),
|
),
|
||||||
fl.Slicing(dim=1, start=1, end=num_mask_tokens + 1),
|
fl.Slicing(dim=1, start=1),
|
||||||
fl.Reshape(num_mask_tokens, embedding_dim, embedding_dim),
|
fl.Reshape(num_mask_tokens, embedding_dim, embedding_dim),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -183,7 +183,7 @@ class IOUPrediction(fl.Chain):
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
),
|
),
|
||||||
fl.Slicing(dim=-1, start=1, end=num_mask_tokens + 1),
|
fl.Slicing(dim=-1, start=1),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,14 @@ def test_slicing_negative_indices() -> None:
|
||||||
assert torch.equal(sliced, expected)
|
assert torch.equal(sliced, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_none_end_slicing() -> None:
|
||||||
|
x = torch.randn(2, 1000, 400)
|
||||||
|
slicing = Slicing(dim=1, start=1)
|
||||||
|
sliced = slicing(x)
|
||||||
|
expected = x[:, 1:, :]
|
||||||
|
assert torch.equal(sliced, expected)
|
||||||
|
|
||||||
|
|
||||||
def test_slicing_step() -> None:
|
def test_slicing_step() -> None:
|
||||||
x = torch.randn(5, 5, 5)
|
x = torch.randn(5, 5, 5)
|
||||||
slicing_layer = Slicing(dim=1, start=0, end=5, step=2)
|
slicing_layer = Slicing(dim=1, start=0, end=5, step=2)
|
||||||
|
@ -27,7 +35,7 @@ def test_slicing_step() -> None:
|
||||||
assert torch.equal(sliced, expected)
|
assert torch.equal(sliced, expected)
|
||||||
|
|
||||||
|
|
||||||
def test_slicing_empty_slice():
|
def test_slicing_empty_slice() -> None:
|
||||||
x = torch.randn(5, 5, 5)
|
x = torch.randn(5, 5, 5)
|
||||||
slicing_layer = Slicing(dim=1, start=3, end=3)
|
slicing_layer = Slicing(dim=1, start=3, end=3)
|
||||||
sliced = slicing_layer(x)
|
sliced = slicing_layer(x)
|
||||||
|
@ -35,7 +43,7 @@ def test_slicing_empty_slice():
|
||||||
assert torch.equal(sliced, expected)
|
assert torch.equal(sliced, expected)
|
||||||
|
|
||||||
|
|
||||||
def test_slicing_full_dimension():
|
def test_slicing_full_dimension() -> None:
|
||||||
x = torch.randn(5, 5, 5)
|
x = torch.randn(5, 5, 5)
|
||||||
slicing_layer = Slicing(dim=2, start=0, end=5)
|
slicing_layer = Slicing(dim=2, start=0, end=5)
|
||||||
sliced = slicing_layer(x)
|
sliced = slicing_layer(x)
|
||||||
|
@ -43,7 +51,7 @@ def test_slicing_full_dimension():
|
||||||
assert torch.equal(sliced, expected)
|
assert torch.equal(sliced, expected)
|
||||||
|
|
||||||
|
|
||||||
def test_slicing_step_greater_than_range():
|
def test_slicing_step_greater_than_range() -> None:
|
||||||
x = torch.randn(5, 5, 5)
|
x = torch.randn(5, 5, 5)
|
||||||
slicing_layer = Slicing(dim=1, start=1, end=3, step=4)
|
slicing_layer = Slicing(dim=1, start=1, end=3, step=4)
|
||||||
sliced = slicing_layer(x)
|
sliced = slicing_layer(x)
|
||||||
|
@ -51,7 +59,7 @@ def test_slicing_step_greater_than_range():
|
||||||
assert torch.equal(sliced, expected)
|
assert torch.equal(sliced, expected)
|
||||||
|
|
||||||
|
|
||||||
def test_slicing_reversed_start_end():
|
def test_slicing_reversed_start_end() -> None:
|
||||||
x = torch.randn(5, 5, 5)
|
x = torch.randn(5, 5, 5)
|
||||||
slicing_layer = Slicing(dim=1, start=4, end=2)
|
slicing_layer = Slicing(dim=1, start=4, end=2)
|
||||||
sliced = slicing_layer(x)
|
sliced = slicing_layer(x)
|
||||||
|
@ -59,7 +67,7 @@ def test_slicing_reversed_start_end():
|
||||||
assert torch.equal(sliced, expected)
|
assert torch.equal(sliced, expected)
|
||||||
|
|
||||||
|
|
||||||
def test_slicing_out_of_bounds_indices():
|
def test_slicing_out_of_bounds_indices() -> None:
|
||||||
x = torch.randn(5, 5, 5)
|
x = torch.randn(5, 5, 5)
|
||||||
slicing_layer = Slicing(dim=1, start=-10, end=10)
|
slicing_layer = Slicing(dim=1, start=-10, end=10)
|
||||||
sliced = slicing_layer(x)
|
sliced = slicing_layer(x)
|
||||||
|
|
Loading…
Reference in a new issue