diff --git a/src/refiners/training_utils/common.py b/src/refiners/training_utils/common.py index abfaaeb..1fe9e1d 100644 --- a/src/refiners/training_utils/common.py +++ b/src/refiners/training_utils/common.py @@ -1,7 +1,6 @@ import random from dataclasses import dataclass from enum import Enum -from functools import wraps from typing import Any, Callable, Iterable import numpy as np @@ -44,38 +43,44 @@ def seed_everything(seed: int | None = None) -> None: cuda.manual_seed_all(seed=seed) -def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., Callable[..., Any]]: +class scoped_seed: """ - Decorator for setting a random seed within the scope of a function. + Context manager and decorator to set a fixed seed within a specific scope. - This decorator sets the random seed for Python's built-in `random` module, - `numpy`, and `torch` and `torch.cuda` at the beginning of the decorated function. After the - function is executed, it restores the state of the random number generators - to what it was before the function was called. This is useful for ensuring - reproducibility for specific parts of the code without affecting randomness - elsewhere. + The seed can be provided directly or as a callable that takes the same arguments + as the decorated function. Supports setting the seed for random, numpy, torch, + and torch.cuda modules. """ - def decorator(func: Callable[..., Any]) -> Callable[..., Any]: - @wraps(func) + def __init__(self, seed: int | Callable[..., int] | None = None): + self.seed = seed + self.actual_seed: int | None = None + + def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]: def inner_wrapper(*args: Any, **kwargs: Any) -> Any: - random_state = random.getstate() - numpy_state = np.random.get_state() - torch_state = torch.get_rng_state() - cuda_torch_state = cuda.get_rng_state() - actual_seed = seed(*args) if callable(seed) else seed - seed_everything(seed=actual_seed) - result = func(*args, **kwargs) - logger.trace(f"Restoring previous seed state") - random.setstate(random_state) - np.random.set_state(numpy_state) - torch.set_rng_state(torch_state) - cuda.set_rng_state(cuda_torch_state) - return result + self.actual_seed = self.seed(*args, **kwargs) if callable(self.seed) else self.seed + with self: + return func(*args, **kwargs) return inner_wrapper - return decorator + def __enter__(self) -> None: + if self.actual_seed is None: + seed = self.seed() if callable(self.seed) else self.seed + else: + seed = self.actual_seed + self.random_state = random.getstate() + self.numpy_state = np.random.get_state() + self.torch_state = torch.get_rng_state() + self.cuda_torch_state = cuda.get_rng_state() + seed_everything(seed) + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + logger.trace(f"Restoring previous seed state") + random.setstate(self.random_state) + np.random.set_state(self.numpy_state) + torch.set_rng_state(self.torch_state) + cuda.set_rng_state(self.cuda_torch_state) class TimeUnit(str, Enum): diff --git a/tests/training_utils/test_common.py b/tests/training_utils/test_common.py index db036f6..c8e60ba 100644 --- a/tests/training_utils/test_common.py +++ b/tests/training_utils/test_common.py @@ -1,6 +1,9 @@ -import pytest +import random -from refiners.training_utils.common import TimeUnit, TimeValue, TimeValueInput, parse_number_unit_field +import pytest +import torch + +from refiners.training_utils.common import TimeUnit, TimeValue, TimeValueInput, parse_number_unit_field, scoped_seed @pytest.mark.parametrize( @@ -31,3 +34,62 @@ def test_parse_number_unit_field(input_value: TimeValueInput, expected_output: T def test_parse_number_unit_field_invalid_input(invalid_input: TimeValueInput): with pytest.raises(ValueError): parse_number_unit_field(invalid_input) + + +@scoped_seed(seed=37) +def pick_a_number() -> int: + return int(torch.randint(0, 100, (1,)).item()) + + +@pytest.mark.parametrize( + "seed, expected_output", + [ + (42, 42), + (37, 31), + (0, 44), + ], +) +def test_scoped_seed_with_specific_seed(seed: int, expected_output: int) -> None: + with scoped_seed(seed): + assert torch.randint(0, 100, (1,)).item() == expected_output + + +@pytest.mark.parametrize( + "seed, expected_output", + [ + (42, 81), + (37, 87), + (0, 49), + ], +) +def test_scoped_seed_with_random_module(seed: int, expected_output: int) -> None: + with scoped_seed(seed): + assert random.randint(0, 100) == expected_output + + +def test_scoped_seed_with_function_call() -> None: + assert pick_a_number() == 31 + + with scoped_seed(37): + assert pick_a_number() == 31 + + +def test_scoped_seed_with_callable_seed() -> None: + with scoped_seed(pick_a_number): + assert pick_a_number() == 31 + + def add_10(n: int) -> int: + return n + 10 + + @scoped_seed(seed=add_10) + def pick_a_number_greater_than_n_plus_10(n: int) -> int: + return int(torch.randint(n, 100, (1,)).item()) + + assert pick_a_number_greater_than_n_plus_10(10) == 81 + + +def test_scoped_seed_restore_state() -> None: + random.seed(37) + with scoped_seed(42): + random.randint(0, 100) + assert random.randint(0, 100) == 87