mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-10 07:21:59 +00:00
68 lines
1.8 KiB
Python
68 lines
1.8 KiB
Python
|
import torch
|
||
|
|
||
|
from refiners.fluxion.layers.basics import Slicing
|
||
|
|
||
|
|
||
|
def test_slicing_positive_indices() -> None:
|
||
|
x = torch.randn(5, 5, 5)
|
||
|
slicing_layer = Slicing(dim=0, start=2, end=5)
|
||
|
sliced = slicing_layer(x)
|
||
|
expected = x[2:5, :]
|
||
|
assert torch.equal(sliced, expected)
|
||
|
|
||
|
|
||
|
def test_slicing_negative_indices() -> None:
|
||
|
x = torch.randn(5, 5, 5)
|
||
|
slicing_layer = Slicing(dim=1, start=-3, end=-1)
|
||
|
sliced = slicing_layer(x)
|
||
|
expected = x[:, -3:-1]
|
||
|
assert torch.equal(sliced, expected)
|
||
|
|
||
|
|
||
|
def test_slicing_step() -> None:
|
||
|
x = torch.randn(5, 5, 5)
|
||
|
slicing_layer = Slicing(dim=1, start=0, end=5, step=2)
|
||
|
sliced = slicing_layer(x)
|
||
|
expected = x[:, 0:5:2]
|
||
|
assert torch.equal(sliced, expected)
|
||
|
|
||
|
|
||
|
def test_slicing_empty_slice():
|
||
|
x = torch.randn(5, 5, 5)
|
||
|
slicing_layer = Slicing(dim=1, start=3, end=3)
|
||
|
sliced = slicing_layer(x)
|
||
|
expected = x[:, 3:3]
|
||
|
assert torch.equal(sliced, expected)
|
||
|
|
||
|
|
||
|
def test_slicing_full_dimension():
|
||
|
x = torch.randn(5, 5, 5)
|
||
|
slicing_layer = Slicing(dim=2, start=0, end=5)
|
||
|
sliced = slicing_layer(x)
|
||
|
expected = x[:, :, :]
|
||
|
assert torch.equal(sliced, expected)
|
||
|
|
||
|
|
||
|
def test_slicing_step_greater_than_range():
|
||
|
x = torch.randn(5, 5, 5)
|
||
|
slicing_layer = Slicing(dim=1, start=1, end=3, step=4)
|
||
|
sliced = slicing_layer(x)
|
||
|
expected = x[:, 1:3:4]
|
||
|
assert torch.equal(sliced, expected)
|
||
|
|
||
|
|
||
|
def test_slicing_reversed_start_end():
|
||
|
x = torch.randn(5, 5, 5)
|
||
|
slicing_layer = Slicing(dim=1, start=4, end=2)
|
||
|
sliced = slicing_layer(x)
|
||
|
expected = x[:, 4:2]
|
||
|
assert torch.equal(sliced, expected)
|
||
|
|
||
|
|
||
|
def test_slicing_out_of_bounds_indices():
|
||
|
x = torch.randn(5, 5, 5)
|
||
|
slicing_layer = Slicing(dim=1, start=-10, end=10)
|
||
|
sliced = slicing_layer(x)
|
||
|
expected = x[:, -10:10]
|
||
|
assert torch.equal(sliced, expected)
|