refiners/tests/training_utils/test_common.py

105 lines
2.8 KiB
Python
Raw Normal View History

import random
2024-03-19 13:20:40 +00:00
import pytest
import torch
2024-03-19 13:20:40 +00:00
from refiners.training_utils.common import TimeUnit, TimeValue, TimeValueInput, parse_number_unit_field, scoped_seed
2024-03-19 13:20:40 +00:00
@pytest.mark.parametrize(
    "input_value, expected_output",
    [
        ("10: step", TimeValue(number=10, unit=TimeUnit.STEP)),
        ("20 :epoch", TimeValue(number=20, unit=TimeUnit.EPOCH)),
        ("30: Iteration", TimeValue(number=30, unit=TimeUnit.ITERATION)),
        (50, TimeValue(number=50, unit=TimeUnit.DEFAULT)),
        ({"number": 100, "unit": "STEP"}, TimeValue(number=100, unit=TimeUnit.STEP)),
        (TimeValue(number=200, unit=TimeUnit.EPOCH), TimeValue(number=200, unit=TimeUnit.EPOCH)),
    ],
)
def test_parse_number_unit_field(input_value: TimeValueInput, expected_output: TimeValue) -> None:
    """Each accepted input shape parses to the expected ``TimeValue``.

    Covered shapes: ``"<number>:<unit>"`` strings (with varied spacing and casing),
    a bare int (mapped to ``TimeUnit.DEFAULT``), a ``{"number", "unit"}`` dict, and an
    existing ``TimeValue`` (returned as an equal value).
    """
    assert parse_number_unit_field(input_value) == expected_output
@pytest.mark.parametrize(
    "invalid_input",
    [
        "invalid:input",
        {"number": "not_a_number", "unit": "step"},
        {"invalid_key": 10},
        None,
    ],
)
def test_parse_number_unit_field_invalid_input(invalid_input: TimeValueInput) -> None:
    """Malformed inputs make ``parse_number_unit_field`` raise ``ValueError``.

    Cases: an unknown unit string, a non-numeric ``number``, a dict missing the
    expected keys, and ``None``.
    """
    with pytest.raises(ValueError):
        parse_number_unit_field(invalid_input)
@scoped_seed(seed=37)
def pick_a_number() -> int:
    """Return one integer in ``[0, 100)`` drawn from torch under ``scoped_seed(seed=37)``."""
    sample = torch.randint(low=0, high=100, size=(1,))
    return int(sample.item())
@pytest.mark.parametrize(
    "seed, expected_output",
    [
        (42, 42),
        (37, 31),
        (0, 44),
    ],
)
def test_scoped_seed_with_specific_seed(seed: int, expected_output: int) -> None:
    """Inside ``scoped_seed(seed)``, torch's first draw is pinned to a known value."""
    with scoped_seed(seed):
        first_draw = torch.randint(0, 100, (1,)).item()
        assert first_draw == expected_output
@pytest.mark.parametrize(
    "seed, expected_output",
    [
        (42, 81),
        (37, 87),
        (0, 49),
    ],
)
def test_scoped_seed_with_random_module(seed: int, expected_output: int) -> None:
    """Inside ``scoped_seed(seed)``, the stdlib ``random`` module's first draw is pinned."""
    with scoped_seed(seed):
        first_draw = random.randint(0, 100)
        assert first_draw == expected_output
def test_scoped_seed_with_function_call() -> None:
    """``pick_a_number`` is reproducible both on its own and inside ``scoped_seed(37)``."""
    expected = 31
    assert pick_a_number() == expected
    with scoped_seed(37):
        assert pick_a_number() == expected
def test_scoped_seed_with_callable_seed() -> None:
    """A callable may be given as the seed, both to the context manager and to the decorator."""
    with scoped_seed(pick_a_number):
        assert pick_a_number() == 31

    def derive_seed(n: int) -> int:
        # Seed callback; presumably invoked with the decorated call's arguments (confirm).
        return n + 10

    @scoped_seed(seed=derive_seed)
    def draw_from_n_to_100(n: int) -> int:
        # Draws from [n, 100); the ``+ 10`` only affects the seed, not the range.
        return int(torch.randint(n, 100, (1,)).item())

    assert draw_from_n_to_100(10) == 81
def test_scoped_seed_restore_state() -> None:
    """Leaving a ``scoped_seed`` block restores the RNG states active before it.

    The global generators are seeded with 37, and a draw is consumed inside a
    ``scoped_seed(42)`` block. The next draw outside the block must then equal
    the first draw a fresh seed of 37 would give: 87 for ``random`` and 31 for
    torch, matching the seed-37 cases of this module's parametrized tests.
    Both libraries are checked because ``scoped_seed`` seeds both.
    """
    random.seed(37)
    torch.manual_seed(37)
    with scoped_seed(42):
        # Advance both generators while the scoped seed is active.
        random.randint(0, 100)
        torch.randint(0, 100, (1,))
    assert random.randint(0, 100) == 87
    assert torch.randint(0, 100, (1,)).item() == 31
2024-04-17 15:17:47 +00:00
def test_import_training_utils() -> None:
    """The ``refiners.training_utils`` package can be imported without error."""
    try:
        import refiners.training_utils as training_utils
    except ImportError:
        pytest.fail("Failed to import refiners.training_utils")
    assert training_utils is not None