fix training_utils requirements check

This commit is contained in:
limiteinductive 2024-04-17 15:17:47 +00:00 committed by Benjamin Trom
parent bf7852b88e
commit 7427c171f6
2 changed files with 34 additions and 19 deletions

View file

@ -4,6 +4,31 @@ from importlib.metadata import requires
from packaging.requirements import Requirement
refiners_requires = requires("refiners")
assert refiners_requires is not None
# Some dependencies have different module names than their package names
req_to_module: dict[str, str] = {
"gitpython": "git",
}
for dep in refiners_requires:
req = Requirement(dep)
marker = req.marker
if marker is None or not marker.evaluate({"extra": "training"}):
continue
module_name = req_to_module.get(req.name, req.name)
try:
import_module(module_name)
except ImportError:
print(
f"Some dependencies are missing: {req.name}. Please install refiners with the `training` extra, e.g. `pip install"
" refiners[training]`",
file=sys.stderr,
)
sys.exit(1)
from refiners.training_utils.callback import Callback, CallbackConfig
from refiners.training_utils.clock import ClockConfig
from refiners.training_utils.config import (
@ -18,25 +43,6 @@ from refiners.training_utils.config import (
from refiners.training_utils.trainer import Trainer, register_callback, register_model
from refiners.training_utils.wandb import WandbConfig, WandbMixin
refiners_requires = requires("refiners")
assert refiners_requires is not None
for dep in refiners_requires:
req = Requirement(dep)
marker = req.marker
if marker is None or not marker.evaluate({"extra": "training"}):
continue
try:
import_module(req.name)
except ImportError:
print(
"Some dependencies are missing. Please install refiners with the `training` extra, e.g. `pip install"
" refiners[training]`",
file=sys.stderr,
)
sys.exit(1)
__all__ = [
"Trainer",
"BaseConfig",

View file

@ -93,3 +93,12 @@ def test_scoped_seed_restore_state() -> None:
with scoped_seed(42):
random.randint(0, 100)
assert random.randint(0, 100) == 87
def test_import_training_utils() -> None:
try:
import refiners.training_utils
except ImportError:
pytest.fail("Failed to import refiners.training_utils")
assert refiners.training_utils is not None