From 4fc5e427b8f0907b4c1cd0aee54d273873430f90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Deltheil?= Date: Fri, 8 Dec 2023 18:55:59 +0100 Subject: [PATCH] training_utils: fix extra detection Requirements could be, e.g.: wandb (>=0.15.7,<0.16.0) ; extra == "training" Or: wandb>=0.16.0; extra == 'training' Follow up of 86c5497 --- pyproject.toml | 1 + requirements.lock | 2 +- src/refiners/training_utils/__init__.py | 9 +++++++-- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c136353..a38928a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "safetensors>=0.4.0", "pillow>=10.1.0", "jaxtyping>=0.2.23", + "packaging>=23.2", ] readme = "README.md" requires-python = ">= 3.10" diff --git a/requirements.lock b/requirements.lock index bdaaf98..a71913a 100644 --- a/requirements.lock +++ b/requirements.lock @@ -53,7 +53,7 @@ nvidia-nvjitlink-cu12==12.3.101 nvidia-nvtx-cu12==12.1.105 opencv-python==4.8.1.78 packaging==23.2 -pandas==2.1.3 +pandas==2.1.4 pillow==10.1.0 piq==0.8.0 prodigyopt==1.0 diff --git a/src/refiners/training_utils/__init__.py b/src/refiners/training_utils/__init__.py index e714f33..66971aa 100644 --- a/src/refiners/training_utils/__init__.py +++ b/src/refiners/training_utils/__init__.py @@ -1,13 +1,18 @@ from importlib import import_module from importlib.metadata import requires +from packaging.requirements import Requirement import sys refiners_requires = requires("refiners") assert refiners_requires is not None -for dep in filter(lambda r: r.endswith('extra == "training"'), refiners_requires): +for dep in refiners_requires: + req = Requirement(dep) + marker = req.marker + if marker is None or not marker.evaluate({"extra": "training"}): + continue try: - import_module(dep.split(" ")[0]) + import_module(req.name) except ImportError: print( "Some dependencies are missing. Please install refiners with the `training` extra, e.g. `pip install"