mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 23:28:45 +00:00
fix training_utils requirements check
This commit is contained in:
parent
bf7852b88e
commit
7427c171f6
|
@ -4,6 +4,31 @@ from importlib.metadata import requires
|
||||||
|
|
||||||
from packaging.requirements import Requirement
|
from packaging.requirements import Requirement
|
||||||
|
|
||||||
|
refiners_requires = requires("refiners")
|
||||||
|
assert refiners_requires is not None
|
||||||
|
|
||||||
|
# Some dependencies have different module names than their package names
|
||||||
|
req_to_module: dict[str, str] = {
|
||||||
|
"gitpython": "git",
|
||||||
|
}
|
||||||
|
|
||||||
|
for dep in refiners_requires:
|
||||||
|
req = Requirement(dep)
|
||||||
|
marker = req.marker
|
||||||
|
if marker is None or not marker.evaluate({"extra": "training"}):
|
||||||
|
continue
|
||||||
|
|
||||||
|
module_name = req_to_module.get(req.name, req.name)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import_module(module_name)
|
||||||
|
except ImportError:
|
||||||
|
print(
|
||||||
|
f"Some dependencies are missing: {req.name}. Please install refiners with the `training` extra, e.g. `pip install"
|
||||||
|
" refiners[training]`",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
from refiners.training_utils.callback import Callback, CallbackConfig
|
from refiners.training_utils.callback import Callback, CallbackConfig
|
||||||
from refiners.training_utils.clock import ClockConfig
|
from refiners.training_utils.clock import ClockConfig
|
||||||
from refiners.training_utils.config import (
|
from refiners.training_utils.config import (
|
||||||
|
@ -18,25 +43,6 @@ from refiners.training_utils.config import (
|
||||||
from refiners.training_utils.trainer import Trainer, register_callback, register_model
|
from refiners.training_utils.trainer import Trainer, register_callback, register_model
|
||||||
from refiners.training_utils.wandb import WandbConfig, WandbMixin
|
from refiners.training_utils.wandb import WandbConfig, WandbMixin
|
||||||
|
|
||||||
refiners_requires = requires("refiners")
|
|
||||||
assert refiners_requires is not None
|
|
||||||
|
|
||||||
for dep in refiners_requires:
|
|
||||||
req = Requirement(dep)
|
|
||||||
marker = req.marker
|
|
||||||
if marker is None or not marker.evaluate({"extra": "training"}):
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
import_module(req.name)
|
|
||||||
except ImportError:
|
|
||||||
print(
|
|
||||||
"Some dependencies are missing. Please install refiners with the `training` extra, e.g. `pip install"
|
|
||||||
" refiners[training]`",
|
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Trainer",
|
"Trainer",
|
||||||
"BaseConfig",
|
"BaseConfig",
|
||||||
|
|
|
@ -93,3 +93,12 @@ def test_scoped_seed_restore_state() -> None:
|
||||||
with scoped_seed(42):
|
with scoped_seed(42):
|
||||||
random.randint(0, 100)
|
random.randint(0, 100)
|
||||||
assert random.randint(0, 100) == 87
|
assert random.randint(0, 100) == 87
|
||||||
|
|
||||||
|
|
||||||
|
def test_import_training_utils() -> None:
|
||||||
|
try:
|
||||||
|
import refiners.training_utils
|
||||||
|
except ImportError:
|
||||||
|
pytest.fail("Failed to import refiners.training_utils")
|
||||||
|
|
||||||
|
assert refiners.training_utils is not None
|
||||||
|
|
Loading…
Reference in a new issue