refiners/tests/fluxion/layers/test_basics.py

76 lines
2 KiB
Python
Raw Permalink Normal View History

2023-12-12 19:09:58 +00:00
import torch
from refiners.fluxion.layers.basics import Slicing
def test_slicing_positive_indices() -> None:
x = torch.randn(5, 5, 5)
slicing_layer = Slicing(dim=0, start=2, end=5)
sliced = slicing_layer(x)
expected = x[2:5, :]
assert torch.equal(sliced, expected)
def test_slicing_negative_indices() -> None:
x = torch.randn(5, 5, 5)
slicing_layer = Slicing(dim=1, start=-3, end=-1)
sliced = slicing_layer(x)
expected = x[:, -3:-1]
assert torch.equal(sliced, expected)
2023-12-13 13:51:18 +00:00
def test_none_end_slicing() -> None:
x = torch.randn(2, 1000, 400)
slicing_layer = Slicing(dim=1, start=1)
sliced = slicing_layer(x)
2023-12-13 13:51:18 +00:00
expected = x[:, 1:, :]
assert torch.equal(sliced, expected)
2023-12-12 19:09:58 +00:00
def test_slicing_step() -> None:
x = torch.randn(5, 5, 5)
slicing_layer = Slicing(dim=1, start=0, end=5, step=2)
sliced = slicing_layer(x)
expected = x[:, 0:5:2]
assert torch.equal(sliced, expected)
2023-12-13 13:51:18 +00:00
def test_slicing_empty_slice() -> None:
2023-12-12 19:09:58 +00:00
x = torch.randn(5, 5, 5)
slicing_layer = Slicing(dim=1, start=3, end=3)
sliced = slicing_layer(x)
expected = x[:, 3:3]
assert torch.equal(sliced, expected)
2023-12-13 13:51:18 +00:00
def test_slicing_full_dimension() -> None:
2023-12-12 19:09:58 +00:00
x = torch.randn(5, 5, 5)
slicing_layer = Slicing(dim=2, start=0, end=5)
sliced = slicing_layer(x)
expected = x[:, :, :]
assert torch.equal(sliced, expected)
2023-12-13 13:51:18 +00:00
def test_slicing_step_greater_than_range() -> None:
2023-12-12 19:09:58 +00:00
x = torch.randn(5, 5, 5)
slicing_layer = Slicing(dim=1, start=1, end=3, step=4)
sliced = slicing_layer(x)
expected = x[:, 1:3:4]
assert torch.equal(sliced, expected)
2023-12-13 13:51:18 +00:00
def test_slicing_reversed_start_end() -> None:
2023-12-12 19:09:58 +00:00
x = torch.randn(5, 5, 5)
slicing_layer = Slicing(dim=1, start=4, end=2)
sliced = slicing_layer(x)
expected = x[:, 4:2]
assert torch.equal(sliced, expected)
2023-12-13 13:51:18 +00:00
def test_slicing_out_of_bounds_indices() -> None:
2023-12-12 19:09:58 +00:00
x = torch.randn(5, 5, 5)
slicing_layer = Slicing(dim=1, start=-10, end=10)
sliced = slicing_layer(x)
expected = x[:, -10:10]
assert torch.equal(sliced, expected)