diff --git a/src/refiners/training_utils/__init__.py b/src/refiners/training_utils/__init__.py index 9fc4b58..236a40c 100644 --- a/src/refiners/training_utils/__init__.py +++ b/src/refiners/training_utils/__init__.py @@ -4,6 +4,31 @@ from importlib.metadata import requires from packaging.requirements import Requirement +refiners_requires = requires("refiners") +assert refiners_requires is not None + +# Some dependencies have different module names than their package names +req_to_module: dict[str, str] = { + "gitpython": "git", +} + +for dep in refiners_requires: + req = Requirement(dep) + marker = req.marker + if marker is None or not marker.evaluate({"extra": "training"}): + continue + + module_name = req_to_module.get(req.name, req.name) + + try: + import_module(module_name) + except ImportError: + print( + f"Some dependencies are missing: {req.name}. Please install refiners with the `training` extra, e.g. `pip install" + " refiners[training]`", + file=sys.stderr, + ) + sys.exit(1) from refiners.training_utils.callback import Callback, CallbackConfig from refiners.training_utils.clock import ClockConfig from refiners.training_utils.config import ( @@ -18,25 +43,6 @@ from refiners.training_utils.config import ( from refiners.training_utils.trainer import Trainer, register_callback, register_model from refiners.training_utils.wandb import WandbConfig, WandbMixin -refiners_requires = requires("refiners") -assert refiners_requires is not None - -for dep in refiners_requires: - req = Requirement(dep) - marker = req.marker - if marker is None or not marker.evaluate({"extra": "training"}): - continue - try: - import_module(req.name) - except ImportError: - print( - "Some dependencies are missing. Please install refiners with the `training` extra, e.g. `pip install" - " refiners[training]`", - file=sys.stderr, - ) - sys.exit(1) - - __all__ = [ "Trainer", "BaseConfig", diff --git a/tests/training_utils/test_common.py b/tests/training_utils/test_common.py index c8e60ba..f490424 100644 --- a/tests/training_utils/test_common.py +++ b/tests/training_utils/test_common.py @@ -93,3 +93,12 @@ def test_scoped_seed_restore_state() -> None: with scoped_seed(42): random.randint(0, 100) assert random.randint(0, 100) == 87 + + +def test_import_training_utils() -> None: + try: + import refiners.training_utils + except ImportError: + pytest.fail("Failed to import refiners.training_utils") + + assert refiners.training_utils is not None