refiners/tests/training_utils/test_common.py
2024-03-19 14:49:24 +01:00

34 lines
1.1 KiB
Python

import pytest
from refiners.training_utils.common import TimeUnit, TimeValue, TimeValueInput, parse_number_unit_field
@pytest.mark.parametrize(
"input_value, expected_output",
[
("10: step", TimeValue(number=10, unit=TimeUnit.STEP)),
("20 :epoch", TimeValue(number=20, unit=TimeUnit.EPOCH)),
("30: Iteration", TimeValue(number=30, unit=TimeUnit.ITERATION)),
(50, TimeValue(number=50, unit=TimeUnit.DEFAULT)),
({"number": 100, "unit": "STEP"}, TimeValue(number=100, unit=TimeUnit.STEP)),
(TimeValue(number=200, unit=TimeUnit.EPOCH), TimeValue(number=200, unit=TimeUnit.EPOCH)),
],
)
def test_parse_number_unit_field(input_value: TimeValueInput, expected_output: TimeValue):
result = parse_number_unit_field(input_value)
assert result == expected_output
@pytest.mark.parametrize(
"invalid_input",
[
"invalid:input",
{"number": "not_a_number", "unit": "step"},
{"invalid_key": 10},
None,
],
)
def test_parse_number_unit_field_invalid_input(invalid_input: TimeValueInput):
with pytest.raises(ValueError):
parse_number_unit_field(invalid_input)