turn scoped_seed into a context manager

This commit is contained in:
limiteinductive 2024-04-12 17:13:48 +00:00 committed by Benjamin Trom
parent 64692c3b5b
commit b9b999ccfe
2 changed files with 94 additions and 27 deletions

View file

@ -1,7 +1,6 @@
import random
from dataclasses import dataclass
from enum import Enum
from functools import wraps
from typing import Any, Callable, Iterable
import numpy as np
@ -44,38 +43,44 @@ def seed_everything(seed: int | None = None) -> None:
cuda.manual_seed_all(seed=seed)
def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., Callable[..., Any]]:
class scoped_seed:
"""
Decorator for setting a random seed within the scope of a function.
Context manager and decorator to set a fixed seed within a specific scope.
This decorator sets the random seed for Python's built-in `random` module,
`numpy`, and `torch` and `torch.cuda` at the beginning of the decorated function. After the
function is executed, it restores the state of the random number generators
to what it was before the function was called. This is useful for ensuring
reproducibility for specific parts of the code without affecting randomness
elsewhere.
The seed can be provided directly or as a callable that takes the same arguments
as the decorated function. Supports setting the seed for random, numpy, torch,
and torch.cuda modules.
"""
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
def __init__(self, seed: int | Callable[..., int] | None = None):
self.seed = seed
self.actual_seed: int | None = None
def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]:
def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
random_state = random.getstate()
numpy_state = np.random.get_state()
torch_state = torch.get_rng_state()
cuda_torch_state = cuda.get_rng_state()
actual_seed = seed(*args) if callable(seed) else seed
seed_everything(seed=actual_seed)
result = func(*args, **kwargs)
logger.trace(f"Restoring previous seed state")
random.setstate(random_state)
np.random.set_state(numpy_state)
torch.set_rng_state(torch_state)
cuda.set_rng_state(cuda_torch_state)
return result
self.actual_seed = self.seed(*args, **kwargs) if callable(self.seed) else self.seed
with self:
return func(*args, **kwargs)
return inner_wrapper
return decorator
def __enter__(self) -> None:
if self.actual_seed is None:
seed = self.seed() if callable(self.seed) else self.seed
else:
seed = self.actual_seed
self.random_state = random.getstate()
self.numpy_state = np.random.get_state()
self.torch_state = torch.get_rng_state()
self.cuda_torch_state = cuda.get_rng_state()
seed_everything(seed)
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
logger.trace(f"Restoring previous seed state")
random.setstate(self.random_state)
np.random.set_state(self.numpy_state)
torch.set_rng_state(self.torch_state)
cuda.set_rng_state(self.cuda_torch_state)
class TimeUnit(str, Enum):

View file

@ -1,6 +1,9 @@
import pytest
import random
from refiners.training_utils.common import TimeUnit, TimeValue, TimeValueInput, parse_number_unit_field
import pytest
import torch
from refiners.training_utils.common import TimeUnit, TimeValue, TimeValueInput, parse_number_unit_field, scoped_seed
@pytest.mark.parametrize(
@ -31,3 +34,62 @@ def test_parse_number_unit_field(input_value: TimeValueInput, expected_output: T
def test_parse_number_unit_field_invalid_input(invalid_input: TimeValueInput):
with pytest.raises(ValueError):
parse_number_unit_field(invalid_input)
@scoped_seed(seed=37)
def pick_a_number() -> int:
return int(torch.randint(0, 100, (1,)).item())
@pytest.mark.parametrize(
"seed, expected_output",
[
(42, 42),
(37, 31),
(0, 44),
],
)
def test_scoped_seed_with_specific_seed(seed: int, expected_output: int) -> None:
with scoped_seed(seed):
assert torch.randint(0, 100, (1,)).item() == expected_output
@pytest.mark.parametrize(
"seed, expected_output",
[
(42, 81),
(37, 87),
(0, 49),
],
)
def test_scoped_seed_with_random_module(seed: int, expected_output: int) -> None:
with scoped_seed(seed):
assert random.randint(0, 100) == expected_output
def test_scoped_seed_with_function_call() -> None:
assert pick_a_number() == 31
with scoped_seed(37):
assert pick_a_number() == 31
def test_scoped_seed_with_callable_seed() -> None:
with scoped_seed(pick_a_number):
assert pick_a_number() == 31
def add_10(n: int) -> int:
return n + 10
@scoped_seed(seed=add_10)
def pick_a_number_greater_than_n_plus_10(n: int) -> int:
return int(torch.randint(n, 100, (1,)).item())
assert pick_a_number_greater_than_n_plus_10(10) == 81
def test_scoped_seed_restore_state() -> None:
random.seed(37)
with scoped_seed(42):
random.randint(0, 100)
assert random.randint(0, 100) == 87