mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
fix training_utils requirements check
This commit is contained in:
parent
bf7852b88e
commit
7427c171f6
|
@ -4,6 +4,31 @@ from importlib.metadata import requires
|
|||
|
||||
from packaging.requirements import Requirement
|
||||
|
||||
refiners_requires = requires("refiners")
|
||||
assert refiners_requires is not None
|
||||
|
||||
# Some dependencies have different module names than their package names
|
||||
req_to_module: dict[str, str] = {
|
||||
"gitpython": "git",
|
||||
}
|
||||
|
||||
for dep in refiners_requires:
|
||||
req = Requirement(dep)
|
||||
marker = req.marker
|
||||
if marker is None or not marker.evaluate({"extra": "training"}):
|
||||
continue
|
||||
|
||||
module_name = req_to_module.get(req.name, req.name)
|
||||
|
||||
try:
|
||||
import_module(module_name)
|
||||
except ImportError:
|
||||
print(
|
||||
f"Some dependencies are missing: {req.name}. Please install refiners with the `training` extra, e.g. `pip install"
|
||||
" refiners[training]`",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
from refiners.training_utils.callback import Callback, CallbackConfig
|
||||
from refiners.training_utils.clock import ClockConfig
|
||||
from refiners.training_utils.config import (
|
||||
|
@ -18,25 +43,6 @@ from refiners.training_utils.config import (
|
|||
from refiners.training_utils.trainer import Trainer, register_callback, register_model
|
||||
from refiners.training_utils.wandb import WandbConfig, WandbMixin
|
||||
|
||||
refiners_requires = requires("refiners")
|
||||
assert refiners_requires is not None
|
||||
|
||||
for dep in refiners_requires:
|
||||
req = Requirement(dep)
|
||||
marker = req.marker
|
||||
if marker is None or not marker.evaluate({"extra": "training"}):
|
||||
continue
|
||||
try:
|
||||
import_module(req.name)
|
||||
except ImportError:
|
||||
print(
|
||||
"Some dependencies are missing. Please install refiners with the `training` extra, e.g. `pip install"
|
||||
" refiners[training]`",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Trainer",
|
||||
"BaseConfig",
|
||||
|
|
|
@ -93,3 +93,12 @@ def test_scoped_seed_restore_state() -> None:
|
|||
with scoped_seed(42):
|
||||
random.randint(0, 100)
|
||||
assert random.randint(0, 100) == 87
|
||||
|
||||
|
||||
def test_import_training_utils() -> None:
|
||||
try:
|
||||
import refiners.training_utils
|
||||
except ImportError:
|
||||
pytest.fail("Failed to import refiners.training_utils")
|
||||
|
||||
assert refiners.training_utils is not None
|
||||
|
|
Loading…
Reference in a new issue