refiners/tests/fluxion/layers/test_basics.py
2023-12-13 09:38:13 +01:00

68 lines
1.8 KiB
Python

import torch
from refiners.fluxion.layers.basics import Slicing
def test_slicing_positive_indices() -> None:
    """Slicing dim 0 over [2, 5) must match the native slice `x[2:5]`."""
    tensor = torch.randn(5, 5, 5)
    layer = Slicing(dim=0, start=2, end=5)
    # Slicing only the leading axis; trailing axes are implicitly kept whole.
    assert torch.equal(layer(tensor), tensor[2:5])
def test_slicing_negative_indices() -> None:
    """Negative bounds on dim 1 must behave like the native slice `x[:, -3:-1]`."""
    tensor = torch.randn(5, 5, 5)
    layer = Slicing(dim=1, start=-3, end=-1)
    actual, reference = layer(tensor), tensor[:, -3:-1]
    assert torch.equal(actual, reference)
def test_slicing_step() -> None:
    """A step of 2 on dim 1 must keep every other index, like `x[:, 0:5:2]`."""
    tensor = torch.randn(5, 5, 5)
    layer = Slicing(dim=1, start=0, end=5, step=2)
    actual, reference = layer(tensor), tensor[:, 0:5:2]
    assert torch.equal(actual, reference)
def test_slicing_empty_slice() -> None:
    """A zero-length range (start == end) must yield an empty slice along `dim`.

    `x[:, 3:3]` selects no indices on dim 1, so the result keeps dims 0 and 2
    but has size 0 on dim 1.
    """
    x = torch.randn(5, 5, 5)
    slicing_layer = Slicing(dim=1, start=3, end=3)
    sliced = slicing_layer(x)
    expected = x[:, 3:3]
    # Pin the shape explicitly: the emptiness is the whole point of this case.
    assert sliced.shape == (5, 0, 5)
    assert torch.equal(sliced, expected)
def test_slicing_full_dimension() -> None:
    """A range spanning the whole of dim 2 (0 to its size, 5) must return the input values unchanged."""
    x = torch.randn(5, 5, 5)
    slicing_layer = Slicing(dim=2, start=0, end=5)
    sliced = slicing_layer(x)
    # `x[:, :, :]` selects every element, i.e. the same values as `x`.
    expected = x[:, :, :]
    assert sliced.shape == x.shape
    assert torch.equal(sliced, expected)
def test_slicing_step_greater_than_range() -> None:
    """A step larger than the range width must keep only the start index.

    On dim 1, `1:3:4` visits index 1 and then jumps past `end`, so exactly one
    index survives.
    """
    x = torch.randn(5, 5, 5)
    slicing_layer = Slicing(dim=1, start=1, end=3, step=4)
    sliced = slicing_layer(x)
    expected = x[:, 1:3:4]
    assert sliced.shape == (5, 1, 5)
    assert torch.equal(sliced, expected)
def test_slicing_reversed_start_end() -> None:
    """start > end (no explicit step) must produce an empty slice, like `x[:, 4:2]`.

    Native slicing with an implicit positive step selects nothing when start is
    past end; the layer is expected to match that rather than reverse or raise.
    """
    x = torch.randn(5, 5, 5)
    slicing_layer = Slicing(dim=1, start=4, end=2)
    sliced = slicing_layer(x)
    expected = x[:, 4:2]
    assert sliced.shape == (5, 0, 5)
    assert torch.equal(sliced, expected)
def test_slicing_out_of_bounds_indices() -> None:
    """Bounds beyond the dimension size must be clamped, like `x[:, -10:10]`.

    On a dim of size 5, native slicing clamps -10 to 0 and 10 to 5, so the
    whole dimension is selected instead of raising an index error.
    """
    x = torch.randn(5, 5, 5)
    slicing_layer = Slicing(dim=1, start=-10, end=10)
    sliced = slicing_layer(x)
    expected = x[:, -10:10]
    assert sliced.shape == (5, 5, 5)
    assert torch.equal(sliced, expected)