diff --git a/pyproject.toml b/pyproject.toml index c136353..a38928a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "safetensors>=0.4.0", "pillow>=10.1.0", "jaxtyping>=0.2.23", + "packaging>=23.2", ] readme = "README.md" requires-python = ">= 3.10" diff --git a/requirements.lock b/requirements.lock index bdaaf98..a71913a 100644 --- a/requirements.lock +++ b/requirements.lock @@ -53,7 +53,7 @@ nvidia-nvjitlink-cu12==12.3.101 nvidia-nvtx-cu12==12.1.105 opencv-python==4.8.1.78 packaging==23.2 -pandas==2.1.3 +pandas==2.1.4 pillow==10.1.0 piq==0.8.0 prodigyopt==1.0 diff --git a/src/refiners/training_utils/__init__.py b/src/refiners/training_utils/__init__.py index e714f33..66971aa 100644 --- a/src/refiners/training_utils/__init__.py +++ b/src/refiners/training_utils/__init__.py @@ -1,13 +1,18 @@ from importlib import import_module from importlib.metadata import requires +from packaging.requirements import Requirement import sys refiners_requires = requires("refiners") assert refiners_requires is not None -for dep in filter(lambda r: r.endswith('extra == "training"'), refiners_requires): +for dep in refiners_requires: + req = Requirement(dep) + marker = req.marker + if marker is None or not marker.evaluate({"extra": "training"}): + continue try: - import_module(dep.split(" ")[0]) + import_module(req.name) except ImportError: print( "Some dependencies are missing. Please install refiners with the `training` extra, e.g. `pip install"