mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 07:08:45 +00:00
34 lines
1.1 KiB
Python
34 lines
1.1 KiB
Python
|
import pytest
|
||
|
|
||
|
from refiners.training_utils.common import TimeUnit, TimeValue, TimeValueInput, parse_number_unit_field
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
|
||
|
"input_value, expected_output",
|
||
|
[
|
||
|
("10: step", TimeValue(number=10, unit=TimeUnit.STEP)),
|
||
|
("20 :epoch", TimeValue(number=20, unit=TimeUnit.EPOCH)),
|
||
|
("30: Iteration", TimeValue(number=30, unit=TimeUnit.ITERATION)),
|
||
|
(50, TimeValue(number=50, unit=TimeUnit.DEFAULT)),
|
||
|
({"number": 100, "unit": "STEP"}, TimeValue(number=100, unit=TimeUnit.STEP)),
|
||
|
(TimeValue(number=200, unit=TimeUnit.EPOCH), TimeValue(number=200, unit=TimeUnit.EPOCH)),
|
||
|
],
|
||
|
)
|
||
|
def test_parse_number_unit_field(input_value: TimeValueInput, expected_output: TimeValue):
|
||
|
result = parse_number_unit_field(input_value)
|
||
|
assert result == expected_output
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
|
||
|
"invalid_input",
|
||
|
[
|
||
|
"invalid:input",
|
||
|
{"number": "not_a_number", "unit": "step"},
|
||
|
{"invalid_key": 10},
|
||
|
None,
|
||
|
],
|
||
|
)
|
||
|
def test_parse_number_unit_field_invalid_input(invalid_input: TimeValueInput):
|
||
|
with pytest.raises(ValueError):
|
||
|
parse_number_unit_field(invalid_input)
|