mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
turn scoped_seed into a context manager
This commit is contained in:
parent
64692c3b5b
commit
b9b999ccfe
|
@ -1,7 +1,6 @@
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import wraps
|
|
||||||
from typing import Any, Callable, Iterable
|
from typing import Any, Callable, Iterable
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -44,38 +43,44 @@ def seed_everything(seed: int | None = None) -> None:
|
||||||
cuda.manual_seed_all(seed=seed)
|
cuda.manual_seed_all(seed=seed)
|
||||||
|
|
||||||
|
|
||||||
def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., Callable[..., Any]]:
|
class scoped_seed:
|
||||||
"""
|
"""
|
||||||
Decorator for setting a random seed within the scope of a function.
|
Context manager and decorator to set a fixed seed within a specific scope.
|
||||||
|
|
||||||
This decorator sets the random seed for Python's built-in `random` module,
|
The seed can be provided directly or as a callable that takes the same arguments
|
||||||
`numpy`, and `torch` and `torch.cuda` at the beginning of the decorated function. After the
|
as the decorated function. Supports setting the seed for random, numpy, torch,
|
||||||
function is executed, it restores the state of the random number generators
|
and torch.cuda modules.
|
||||||
to what it was before the function was called. This is useful for ensuring
|
|
||||||
reproducibility for specific parts of the code without affecting randomness
|
|
||||||
elsewhere.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
def __init__(self, seed: int | Callable[..., int] | None = None):
|
||||||
@wraps(func)
|
self.seed = seed
|
||||||
|
self.actual_seed: int | None = None
|
||||||
|
|
||||||
|
def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
|
def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
random_state = random.getstate()
|
self.actual_seed = self.seed(*args, **kwargs) if callable(self.seed) else self.seed
|
||||||
numpy_state = np.random.get_state()
|
with self:
|
||||||
torch_state = torch.get_rng_state()
|
return func(*args, **kwargs)
|
||||||
cuda_torch_state = cuda.get_rng_state()
|
|
||||||
actual_seed = seed(*args) if callable(seed) else seed
|
|
||||||
seed_everything(seed=actual_seed)
|
|
||||||
result = func(*args, **kwargs)
|
|
||||||
logger.trace(f"Restoring previous seed state")
|
|
||||||
random.setstate(random_state)
|
|
||||||
np.random.set_state(numpy_state)
|
|
||||||
torch.set_rng_state(torch_state)
|
|
||||||
cuda.set_rng_state(cuda_torch_state)
|
|
||||||
return result
|
|
||||||
|
|
||||||
return inner_wrapper
|
return inner_wrapper
|
||||||
|
|
||||||
return decorator
|
def __enter__(self) -> None:
|
||||||
|
if self.actual_seed is None:
|
||||||
|
seed = self.seed() if callable(self.seed) else self.seed
|
||||||
|
else:
|
||||||
|
seed = self.actual_seed
|
||||||
|
self.random_state = random.getstate()
|
||||||
|
self.numpy_state = np.random.get_state()
|
||||||
|
self.torch_state = torch.get_rng_state()
|
||||||
|
self.cuda_torch_state = cuda.get_rng_state()
|
||||||
|
seed_everything(seed)
|
||||||
|
|
||||||
|
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||||
|
logger.trace(f"Restoring previous seed state")
|
||||||
|
random.setstate(self.random_state)
|
||||||
|
np.random.set_state(self.numpy_state)
|
||||||
|
torch.set_rng_state(self.torch_state)
|
||||||
|
cuda.set_rng_state(self.cuda_torch_state)
|
||||||
|
|
||||||
|
|
||||||
class TimeUnit(str, Enum):
|
class TimeUnit(str, Enum):
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
import pytest
|
import random
|
||||||
|
|
||||||
from refiners.training_utils.common import TimeUnit, TimeValue, TimeValueInput, parse_number_unit_field
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from refiners.training_utils.common import TimeUnit, TimeValue, TimeValueInput, parse_number_unit_field, scoped_seed
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -31,3 +34,62 @@ def test_parse_number_unit_field(input_value: TimeValueInput, expected_output: T
|
||||||
def test_parse_number_unit_field_invalid_input(invalid_input: TimeValueInput):
|
def test_parse_number_unit_field_invalid_input(invalid_input: TimeValueInput):
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
parse_number_unit_field(invalid_input)
|
parse_number_unit_field(invalid_input)
|
||||||
|
|
||||||
|
|
||||||
|
@scoped_seed(seed=37)
|
||||||
|
def pick_a_number() -> int:
|
||||||
|
return int(torch.randint(0, 100, (1,)).item())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"seed, expected_output",
|
||||||
|
[
|
||||||
|
(42, 42),
|
||||||
|
(37, 31),
|
||||||
|
(0, 44),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_scoped_seed_with_specific_seed(seed: int, expected_output: int) -> None:
|
||||||
|
with scoped_seed(seed):
|
||||||
|
assert torch.randint(0, 100, (1,)).item() == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"seed, expected_output",
|
||||||
|
[
|
||||||
|
(42, 81),
|
||||||
|
(37, 87),
|
||||||
|
(0, 49),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_scoped_seed_with_random_module(seed: int, expected_output: int) -> None:
|
||||||
|
with scoped_seed(seed):
|
||||||
|
assert random.randint(0, 100) == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_scoped_seed_with_function_call() -> None:
|
||||||
|
assert pick_a_number() == 31
|
||||||
|
|
||||||
|
with scoped_seed(37):
|
||||||
|
assert pick_a_number() == 31
|
||||||
|
|
||||||
|
|
||||||
|
def test_scoped_seed_with_callable_seed() -> None:
|
||||||
|
with scoped_seed(pick_a_number):
|
||||||
|
assert pick_a_number() == 31
|
||||||
|
|
||||||
|
def add_10(n: int) -> int:
|
||||||
|
return n + 10
|
||||||
|
|
||||||
|
@scoped_seed(seed=add_10)
|
||||||
|
def pick_a_number_greater_than_n_plus_10(n: int) -> int:
|
||||||
|
return int(torch.randint(n, 100, (1,)).item())
|
||||||
|
|
||||||
|
assert pick_a_number_greater_than_n_plus_10(10) == 81
|
||||||
|
|
||||||
|
|
||||||
|
def test_scoped_seed_restore_state() -> None:
|
||||||
|
random.seed(37)
|
||||||
|
with scoped_seed(42):
|
||||||
|
random.randint(0, 100)
|
||||||
|
assert random.randint(0, 100) == 87
|
||||||
|
|
Loading…
Reference in a new issue