training_utils: fix extra detection

Requirements could be, e.g.:

    wandb (>=0.15.7,<0.16.0) ; extra == "training"

Or:

    wandb>=0.16.0; extra == 'training'

Follow up of 86c5497
This commit is contained in:
Cédric Deltheil 2023-12-08 18:55:59 +01:00 committed by Cédric Deltheil
parent 86c54977b9
commit 4fc5e427b8
3 changed files with 9 additions and 3 deletions

View file

@ -9,6 +9,7 @@ dependencies = [
"safetensors>=0.4.0",
"pillow>=10.1.0",
"jaxtyping>=0.2.23",
"packaging>=23.2",
]
readme = "README.md"
requires-python = ">= 3.10"

View file

@ -53,7 +53,7 @@ nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
opencv-python==4.8.1.78
packaging==23.2
pandas==2.1.3
pandas==2.1.4
pillow==10.1.0
piq==0.8.0
prodigyopt==1.0

View file

@ -1,13 +1,18 @@
from importlib import import_module
from importlib.metadata import requires
from packaging.requirements import Requirement
import sys
refiners_requires = requires("refiners")
assert refiners_requires is not None
for dep in filter(lambda r: r.endswith('extra == "training"'), refiners_requires):
for dep in refiners_requires:
req = Requirement(dep)
marker = req.marker
if marker is None or not marker.evaluate({"extra": "training"}):
continue
try:
import_module(dep.split(" ")[0])
import_module(req.name)
except ImportError:
print(
"Some dependencies are missing. Please install refiners with the `training` extra, e.g. `pip install"