mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 07:08:45 +00:00
turn scoped_seed into a context manager
This commit is contained in:
parent
64692c3b5b
commit
b9b999ccfe
|
@ -1,7 +1,6 @@
|
|||
import random
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Iterable
|
||||
|
||||
import numpy as np
|
||||
|
@ -44,38 +43,44 @@ def seed_everything(seed: int | None = None) -> None:
|
|||
cuda.manual_seed_all(seed=seed)
|
||||
|
||||
|
||||
def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., Callable[..., Any]]:
|
||||
class scoped_seed:
|
||||
"""
|
||||
Decorator for setting a random seed within the scope of a function.
|
||||
Context manager and decorator to set a fixed seed within a specific scope.
|
||||
|
||||
This decorator sets the random seed for Python's built-in `random` module,
|
||||
`numpy`, and `torch` and `torch.cuda` at the beginning of the decorated function. After the
|
||||
function is executed, it restores the state of the random number generators
|
||||
to what it was before the function was called. This is useful for ensuring
|
||||
reproducibility for specific parts of the code without affecting randomness
|
||||
elsewhere.
|
||||
The seed can be provided directly or as a callable that takes the same arguments
|
||||
as the decorated function. Supports setting the seed for random, numpy, torch,
|
||||
and torch.cuda modules.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
@wraps(func)
|
||||
def __init__(self, seed: int | Callable[..., int] | None = None):
|
||||
self.seed = seed
|
||||
self.actual_seed: int | None = None
|
||||
|
||||
def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
random_state = random.getstate()
|
||||
numpy_state = np.random.get_state()
|
||||
torch_state = torch.get_rng_state()
|
||||
cuda_torch_state = cuda.get_rng_state()
|
||||
actual_seed = seed(*args) if callable(seed) else seed
|
||||
seed_everything(seed=actual_seed)
|
||||
result = func(*args, **kwargs)
|
||||
logger.trace(f"Restoring previous seed state")
|
||||
random.setstate(random_state)
|
||||
np.random.set_state(numpy_state)
|
||||
torch.set_rng_state(torch_state)
|
||||
cuda.set_rng_state(cuda_torch_state)
|
||||
return result
|
||||
self.actual_seed = self.seed(*args, **kwargs) if callable(self.seed) else self.seed
|
||||
with self:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return inner_wrapper
|
||||
|
||||
return decorator
|
||||
def __enter__(self) -> None:
|
||||
if self.actual_seed is None:
|
||||
seed = self.seed() if callable(self.seed) else self.seed
|
||||
else:
|
||||
seed = self.actual_seed
|
||||
self.random_state = random.getstate()
|
||||
self.numpy_state = np.random.get_state()
|
||||
self.torch_state = torch.get_rng_state()
|
||||
self.cuda_torch_state = cuda.get_rng_state()
|
||||
seed_everything(seed)
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
logger.trace(f"Restoring previous seed state")
|
||||
random.setstate(self.random_state)
|
||||
np.random.set_state(self.numpy_state)
|
||||
torch.set_rng_state(self.torch_state)
|
||||
cuda.set_rng_state(self.cuda_torch_state)
|
||||
|
||||
|
||||
class TimeUnit(str, Enum):
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
import pytest
|
||||
import random
|
||||
|
||||
from refiners.training_utils.common import TimeUnit, TimeValue, TimeValueInput, parse_number_unit_field
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from refiners.training_utils.common import TimeUnit, TimeValue, TimeValueInput, parse_number_unit_field, scoped_seed
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -31,3 +34,62 @@ def test_parse_number_unit_field(input_value: TimeValueInput, expected_output: T
|
|||
def test_parse_number_unit_field_invalid_input(invalid_input: TimeValueInput):
|
||||
with pytest.raises(ValueError):
|
||||
parse_number_unit_field(invalid_input)
|
||||
|
||||
|
||||
@scoped_seed(seed=37)
|
||||
def pick_a_number() -> int:
|
||||
return int(torch.randint(0, 100, (1,)).item())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"seed, expected_output",
|
||||
[
|
||||
(42, 42),
|
||||
(37, 31),
|
||||
(0, 44),
|
||||
],
|
||||
)
|
||||
def test_scoped_seed_with_specific_seed(seed: int, expected_output: int) -> None:
|
||||
with scoped_seed(seed):
|
||||
assert torch.randint(0, 100, (1,)).item() == expected_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"seed, expected_output",
|
||||
[
|
||||
(42, 81),
|
||||
(37, 87),
|
||||
(0, 49),
|
||||
],
|
||||
)
|
||||
def test_scoped_seed_with_random_module(seed: int, expected_output: int) -> None:
|
||||
with scoped_seed(seed):
|
||||
assert random.randint(0, 100) == expected_output
|
||||
|
||||
|
||||
def test_scoped_seed_with_function_call() -> None:
|
||||
assert pick_a_number() == 31
|
||||
|
||||
with scoped_seed(37):
|
||||
assert pick_a_number() == 31
|
||||
|
||||
|
||||
def test_scoped_seed_with_callable_seed() -> None:
|
||||
with scoped_seed(pick_a_number):
|
||||
assert pick_a_number() == 31
|
||||
|
||||
def add_10(n: int) -> int:
|
||||
return n + 10
|
||||
|
||||
@scoped_seed(seed=add_10)
|
||||
def pick_a_number_greater_than_n_plus_10(n: int) -> int:
|
||||
return int(torch.randint(n, 100, (1,)).item())
|
||||
|
||||
assert pick_a_number_greater_than_n_plus_10(10) == 81
|
||||
|
||||
|
||||
def test_scoped_seed_restore_state() -> None:
|
||||
random.seed(37)
|
||||
with scoped_seed(42):
|
||||
random.randint(0, 100)
|
||||
assert random.randint(0, 100) == 87
|
||||
|
|
Loading…
Reference in a new issue