feat: ugly hack to have multiple classes
This commit is contained in:
parent
57853be03e
commit
eb3dabe8d7
48
poetry.lock
generated
48
poetry.lock
generated
|
@ -24,7 +24,7 @@ multidict = ">=4.5,<7.0"
|
||||||
yarl = ">=1.0,<2.0"
|
yarl = ">=1.0,<2.0"
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
speedups = ["aiodns", "brotli", "cchardet"]
|
speedups = ["Brotli", "aiodns", "cchardet"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "aiosignal"
|
name = "aiosignal"
|
||||||
|
@ -97,9 +97,9 @@ optional = false
|
||||||
python-versions = ">=3.5"
|
python-versions = ">=3.5"
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
dev = ["cloudpickle", "coverage[toml] (>=5.0.2)", "furo", "hypothesis", "mypy (>=0.900,!=0.940)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "sphinx", "sphinx-notfound-page", "zope-interface"]
|
dev = ["cloudpickle", "coverage[toml] (>=5.0.2)", "furo", "hypothesis", "mypy (>=0.900,!=0.940)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "sphinx", "sphinx-notfound-page", "zope.interface"]
|
||||||
docs = ["furo", "sphinx", "sphinx-notfound-page", "zope-interface"]
|
docs = ["furo", "sphinx", "sphinx-notfound-page", "zope.interface"]
|
||||||
tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "zope-interface"]
|
tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "zope.interface"]
|
||||||
tests_no_zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"]
|
tests_no_zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -127,7 +127,7 @@ stevedore = ">=1.20.0"
|
||||||
[package.extras]
|
[package.extras]
|
||||||
test = ["beautifulsoup4 (>=4.8.0)", "coverage (>=4.5.4)", "fixtures (>=3.0.0)", "flake8 (>=4.0.0)", "pylint (==1.9.4)", "stestr (>=2.5.0)", "testscenarios (>=0.5.0)", "testtools (>=2.3.0)", "toml"]
|
test = ["beautifulsoup4 (>=4.8.0)", "coverage (>=4.5.4)", "fixtures (>=3.0.0)", "flake8 (>=4.0.0)", "pylint (==1.9.4)", "stestr (>=2.5.0)", "testscenarios (>=0.5.0)", "testtools (>=2.3.0)", "toml"]
|
||||||
toml = ["toml"]
|
toml = ["toml"]
|
||||||
yaml = ["pyyaml"]
|
yaml = ["PyYAML"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "black"
|
name = "black"
|
||||||
|
@ -352,7 +352,7 @@ Flake8 = ">=5,<6"
|
||||||
TOMLi = {version = "*", markers = "python_version < \"3.11\""}
|
TOMLi = {version = "*", markers = "python_version < \"3.11\""}
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
test = ["pytest", "pytest-cov"]
|
test = ["pyTest", "pyTest-cov"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "flatbuffers"
|
name = "flatbuffers"
|
||||||
|
@ -544,17 +544,17 @@ pillow = ">=8.3.2"
|
||||||
all-plugins = ["astropy", "av", "imageio-ffmpeg", "opencv-python", "psutil", "tifffile"]
|
all-plugins = ["astropy", "av", "imageio-ffmpeg", "opencv-python", "psutil", "tifffile"]
|
||||||
all-plugins-pypy = ["av", "imageio-ffmpeg", "psutil", "tifffile"]
|
all-plugins-pypy = ["av", "imageio-ffmpeg", "psutil", "tifffile"]
|
||||||
build = ["wheel"]
|
build = ["wheel"]
|
||||||
dev = ["black", "flake8", "fsspec", "invoke", "pytest", "pytest-cov"]
|
dev = ["black", "flake8", "fsspec[github]", "invoke", "pytest", "pytest-cov"]
|
||||||
docs = ["numpydoc", "pydata-sphinx-theme", "sphinx"]
|
docs = ["numpydoc", "pydata-sphinx-theme", "sphinx"]
|
||||||
ffmpeg = ["imageio-ffmpeg", "psutil"]
|
ffmpeg = ["imageio-ffmpeg", "psutil"]
|
||||||
fits = ["astropy"]
|
fits = ["astropy"]
|
||||||
full = ["astropy", "av", "black", "flake8", "fsspec", "gdal", "imageio-ffmpeg", "invoke", "itk", "numpydoc", "opencv-python", "psutil", "pydata-sphinx-theme", "pytest", "pytest-cov", "sphinx", "tifffile", "wheel"]
|
full = ["astropy", "av", "black", "flake8", "fsspec[github]", "gdal", "imageio-ffmpeg", "invoke", "itk", "numpydoc", "opencv-python", "psutil", "pydata-sphinx-theme", "pytest", "pytest-cov", "sphinx", "tifffile", "wheel"]
|
||||||
gdal = ["gdal"]
|
gdal = ["gdal"]
|
||||||
itk = ["itk"]
|
itk = ["itk"]
|
||||||
linting = ["black", "flake8"]
|
linting = ["black", "flake8"]
|
||||||
opencv = ["opencv-python"]
|
opencv = ["opencv-python"]
|
||||||
pyav = ["av"]
|
pyav = ["av"]
|
||||||
test = ["fsspec", "invoke", "pytest", "pytest-cov"]
|
test = ["fsspec[github]", "invoke", "pytest", "pytest-cov"]
|
||||||
tifffile = ["tifffile"]
|
tifffile = ["tifffile"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -571,11 +571,11 @@ zipp = ">=0.5"
|
||||||
[package.extras]
|
[package.extras]
|
||||||
docs = ["jaraco.packaging (>=9)", "rst.linker (>=1.9)", "sphinx"]
|
docs = ["jaraco.packaging (>=9)", "rst.linker (>=1.9)", "sphinx"]
|
||||||
perf = ["ipython"]
|
perf = ["ipython"]
|
||||||
testing = ["flufl-flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)"]
|
testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ipykernel"
|
name = "ipykernel"
|
||||||
version = "6.15.2"
|
version = "6.15.3"
|
||||||
description = "IPython Kernel for Jupyter"
|
description = "IPython Kernel for Jupyter"
|
||||||
category = "dev"
|
category = "dev"
|
||||||
optional = false
|
optional = false
|
||||||
|
@ -728,11 +728,11 @@ torch = ">=1.7.1"
|
||||||
torchmetrics = ">=0.4.1"
|
torchmetrics = ">=0.4.1"
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
dev = ["ale-py (>=0.7)", "atari-py (==0.2.6)", "atari-py (>=0.2.0,<0.3.0)", "box2d-py (>=2.3.0,<2.4.0)", "check-manifest", "codecov (>=2.1)", "flake8", "gym[atari] (>=0.17.2,<0.20.0)", "isort (>=5.6.4)", "jsonargparse", "matplotlib", "mypy (>=0.790)", "opencv-python-headless", "pillow", "pre-commit (>=1.0)", "pytest (>=6.0)", "pytest-cov (>2.10)", "scikit-learn (>=0.23)", "scipy", "sparseml", "torchvision (>=0.8.2)", "twine (>=3.2)", "wandb"]
|
dev = ["Pillow", "ale-py (>=0.7)", "atari-py (==0.2.6)", "atari-py (>=0.2.0,<0.3.0)", "box2d-py (>=2.3.0,<2.4.0)", "check-manifest", "codecov (>=2.1)", "flake8", "gym[atari] (>=0.17.2,<0.20.0)", "isort (>=5.6.4)", "jsonargparse[signatures]", "matplotlib", "mypy (>=0.790)", "opencv-python-headless", "pre-commit (>=1.0)", "pytest (>=6.0)", "pytest-cov (>2.10)", "scikit-learn (>=0.23)", "scipy", "sparseml", "torchvision (>=0.8.2)", "twine (>=3.2)", "wandb"]
|
||||||
extra = ["atari-py (>=0.2.0,<0.3.0)", "box2d-py (>=2.3.0,<2.4.0)", "gym[atari] (>=0.17.2,<0.20.0)", "matplotlib", "opencv-python-headless", "pillow", "scikit-learn (>=0.23)", "scipy", "torchvision (>=0.8.2)", "wandb"]
|
extra = ["Pillow", "atari-py (>=0.2.0,<0.3.0)", "box2d-py (>=2.3.0,<2.4.0)", "gym[atari] (>=0.17.2,<0.20.0)", "matplotlib", "opencv-python-headless", "scikit-learn (>=0.23)", "scipy", "torchvision (>=0.8.2)", "wandb"]
|
||||||
loggers = ["matplotlib", "scipy", "wandb"]
|
loggers = ["matplotlib", "scipy", "wandb"]
|
||||||
models = ["atari-py (>=0.2.0,<0.3.0)", "box2d-py (>=2.3.0,<2.4.0)", "gym[atari] (>=0.17.2,<0.20.0)", "opencv-python-headless", "pillow", "scikit-learn (>=0.23)", "torchvision (>=0.8.2)"]
|
models = ["Pillow", "atari-py (>=0.2.0,<0.3.0)", "box2d-py (>=2.3.0,<2.4.0)", "gym[atari] (>=0.17.2,<0.20.0)", "opencv-python-headless", "scikit-learn (>=0.23)", "torchvision (>=0.8.2)"]
|
||||||
test = ["ale-py (>=0.7)", "atari-py (==0.2.6)", "check-manifest", "codecov (>=2.1)", "flake8", "isort (>=5.6.4)", "jsonargparse", "mypy (>=0.790)", "pre-commit (>=1.0)", "pytest (>=6.0)", "pytest-cov (>2.10)", "scikit-learn (>=0.23)", "sparseml", "twine (>=3.2)"]
|
test = ["ale-py (>=0.7)", "atari-py (==0.2.6)", "check-manifest", "codecov (>=2.1)", "flake8", "isort (>=5.6.4)", "jsonargparse[signatures]", "mypy (>=0.790)", "pre-commit (>=1.0)", "pytest (>=6.0)", "pytest-cov (>2.10)", "scikit-learn (>=0.23)", "sparseml", "twine (>=3.2)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "markdown"
|
name = "markdown"
|
||||||
|
@ -1275,10 +1275,10 @@ tqdm = ">=4.57.0"
|
||||||
typing-extensions = ">=4.0.0"
|
typing-extensions = ">=4.0.0"
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
all = ["cloudpickle (>=1.3)", "codecov (>=2.1)", "comet-ml (>=3.1.12)", "coverage (>=6.4)", "deepspeed (>=0.6.0)", "fairscale (>=0.4.5)", "fastapi", "gcsfs (>=2021.5.0)", "gym[classic_control] (>=0.17.0)", "hivemind (>=1.0.1)", "horovod (>=0.21.2,!=0.24.0)", "hydra-core (>=1.0.5)", "ipython", "jsonargparse[signatures] (>=4.12.0)", "matplotlib (>3.1)", "mlflow (>=1.0.0)", "mypy (==0.971)", "neptune-client (>=0.10.0)", "omegaconf (>=2.0.5)", "onnxruntime", "pandas (>1.0)", "pre-commit (>=1.0)", "protobuf (<=3.20.1)", "psutil", "pytest (>=7.0)", "pytest-cov", "pytest-forked", "pytest-rerunfailures (>=10.2)", "rich (>=10.14.0,!=10.15.0.a)", "scikit-learn (>0.22.1)", "torchtext (>=0.10)", "torchvision (>=0.10)", "uvicorn", "wandb (>=0.10.22)"]
|
all = ["cloudpickle (>=1.3)", "codecov (>=2.1)", "comet-ml (>=3.1.12)", "coverage (>=6.4)", "deepspeed (>=0.6.0)", "fairscale (>=0.4.5)", "fastapi", "gcsfs (>=2021.5.0)", "gym[classic_control] (>=0.17.0)", "hivemind (>=1.0.1)", "horovod (>=0.21.2,!=0.24.0)", "hydra-core (>=1.0.5)", "ipython[all]", "jsonargparse[signatures] (>=4.12.0)", "matplotlib (>3.1)", "mlflow (>=1.0.0)", "mypy (==0.971)", "neptune-client (>=0.10.0)", "omegaconf (>=2.0.5)", "onnxruntime", "pandas (>1.0)", "pre-commit (>=1.0)", "protobuf (<=3.20.1)", "psutil", "pytest (>=7.0)", "pytest-cov", "pytest-forked", "pytest-rerunfailures (>=10.2)", "rich (>=10.14.0,!=10.15.0.a)", "scikit-learn (>0.22.1)", "torchtext (>=0.10)", "torchvision (>=0.10)", "uvicorn", "wandb (>=0.10.22)"]
|
||||||
deepspeed = ["deepspeed (>=0.6.0)"]
|
deepspeed = ["deepspeed (>=0.6.0)"]
|
||||||
dev = ["cloudpickle (>=1.3)", "codecov (>=2.1)", "comet-ml (>=3.1.12)", "coverage (>=6.4)", "fastapi", "gcsfs (>=2021.5.0)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.12.0)", "matplotlib (>3.1)", "mlflow (>=1.0.0)", "mypy (==0.971)", "neptune-client (>=0.10.0)", "omegaconf (>=2.0.5)", "onnxruntime", "pandas (>1.0)", "pre-commit (>=1.0)", "protobuf (<=3.20.1)", "psutil", "pytest (>=7.0)", "pytest-cov", "pytest-forked", "pytest-rerunfailures (>=10.2)", "rich (>=10.14.0,!=10.15.0.a)", "scikit-learn (>0.22.1)", "torchtext (>=0.10)", "uvicorn", "wandb (>=0.10.22)"]
|
dev = ["cloudpickle (>=1.3)", "codecov (>=2.1)", "comet-ml (>=3.1.12)", "coverage (>=6.4)", "fastapi", "gcsfs (>=2021.5.0)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.12.0)", "matplotlib (>3.1)", "mlflow (>=1.0.0)", "mypy (==0.971)", "neptune-client (>=0.10.0)", "omegaconf (>=2.0.5)", "onnxruntime", "pandas (>1.0)", "pre-commit (>=1.0)", "protobuf (<=3.20.1)", "psutil", "pytest (>=7.0)", "pytest-cov", "pytest-forked", "pytest-rerunfailures (>=10.2)", "rich (>=10.14.0,!=10.15.0.a)", "scikit-learn (>0.22.1)", "torchtext (>=0.10)", "uvicorn", "wandb (>=0.10.22)"]
|
||||||
examples = ["gym[classic_control] (>=0.17.0)", "ipython", "torchvision (>=0.10)"]
|
examples = ["gym[classic_control] (>=0.17.0)", "ipython[all]", "torchvision (>=0.10)"]
|
||||||
extra = ["gcsfs (>=2021.5.0)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.12.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "protobuf (<=3.20.1)", "rich (>=10.14.0,!=10.15.0.a)", "torchtext (>=0.10)"]
|
extra = ["gcsfs (>=2021.5.0)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.12.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "protobuf (<=3.20.1)", "rich (>=10.14.0,!=10.15.0.a)", "torchtext (>=0.10)"]
|
||||||
fairscale = ["fairscale (>=0.4.5)"]
|
fairscale = ["fairscale (>=0.4.5)"]
|
||||||
hivemind = ["hivemind (>=1.0.1)"]
|
hivemind = ["hivemind (>=1.0.1)"]
|
||||||
|
@ -1421,7 +1421,7 @@ tifffile = ">=2019.7.26"
|
||||||
[package.extras]
|
[package.extras]
|
||||||
data = ["pooch (>=1.3.0)"]
|
data = ["pooch (>=1.3.0)"]
|
||||||
docs = ["cloudpickle (>=0.2.1)", "dask[array] (>=0.15.0,!=2.17.0)", "ipywidgets", "kaleido", "matplotlib (>=3.3)", "myst-parser", "numpydoc (>=1.0)", "pandas (>=0.23.0)", "plotly (>=4.14.0)", "pooch (>=1.3.0)", "pytest-runner", "scikit-learn", "seaborn (>=0.7.1)", "sphinx (>=1.8)", "sphinx-copybutton", "sphinx-gallery (>=0.10.1)", "tifffile (>=2020.5.30)"]
|
docs = ["cloudpickle (>=0.2.1)", "dask[array] (>=0.15.0,!=2.17.0)", "ipywidgets", "kaleido", "matplotlib (>=3.3)", "myst-parser", "numpydoc (>=1.0)", "pandas (>=0.23.0)", "plotly (>=4.14.0)", "pooch (>=1.3.0)", "pytest-runner", "scikit-learn", "seaborn (>=0.7.1)", "sphinx (>=1.8)", "sphinx-copybutton", "sphinx-gallery (>=0.10.1)", "tifffile (>=2020.5.30)"]
|
||||||
optional = ["astropy (>=3.1.2)", "cloudpickle (>=0.2.1)", "dask[array] (>=1.0.0,!=2.17.0)", "matplotlib (>=3.0.3)", "pooch (>=1.3.0)", "pyamg", "qtpy", "simpleitk"]
|
optional = ["SimpleITK", "astropy (>=3.1.2)", "cloudpickle (>=0.2.1)", "dask[array] (>=1.0.0,!=2.17.0)", "matplotlib (>=3.0.3)", "pooch (>=1.3.0)", "pyamg", "qtpy"]
|
||||||
test = ["asv", "codecov", "flake8", "matplotlib (>=3.0.3)", "pooch (>=1.3.0)", "pytest (>=5.2.0)", "pytest-cov (>=2.7.0)", "pytest-faulthandler", "pytest-localserver"]
|
test = ["asv", "codecov", "flake8", "matplotlib (>=3.0.3)", "pooch (>=1.3.0)", "pytest (>=5.2.0)", "pytest-cov (>=2.7.0)", "pytest-faulthandler", "pytest-localserver"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -1508,8 +1508,8 @@ python-versions = ">=3.7"
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
|
docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
|
||||||
testing = ["build", "filelock (>=3.4.0)", "flake8 (<5)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mock", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
|
testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8 (<5)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mock", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
|
||||||
testing-integration = ["build", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
|
testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "setuptools-scm"
|
name = "setuptools-scm"
|
||||||
|
@ -1899,12 +1899,12 @@ python-versions = ">=3.7"
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
docs = ["jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx"]
|
docs = ["jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx"]
|
||||||
testing = ["func-timeout", "jaraco-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
|
testing = ["func-timeout", "jaraco.itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
|
||||||
|
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "1.1"
|
lock-version = "1.1"
|
||||||
python-versions = ">=3.8,<3.11"
|
python-versions = ">=3.8,<3.11"
|
||||||
content-hash = "8eba042ca7188b409be6ea812f315b9aa6964791195e21ea8ea66655bb4c4fab"
|
content-hash = "5f210ee2c6b72cc8e0021f082cc9c7a985528a82c437dca40086704faed8d2f6"
|
||||||
|
|
||||||
[metadata.files]
|
[metadata.files]
|
||||||
absl-py = [
|
absl-py = [
|
||||||
|
@ -2357,8 +2357,8 @@ importlib-metadata = [
|
||||||
{file = "importlib_metadata-4.12.0.tar.gz", hash = "sha256:637245b8bab2b6502fcbc752cc4b7a6f6243bb02b31c5c26156ad103d3d45670"},
|
{file = "importlib_metadata-4.12.0.tar.gz", hash = "sha256:637245b8bab2b6502fcbc752cc4b7a6f6243bb02b31c5c26156ad103d3d45670"},
|
||||||
]
|
]
|
||||||
ipykernel = [
|
ipykernel = [
|
||||||
{file = "ipykernel-6.15.2-py3-none-any.whl", hash = "sha256:59183ef833b82c72211aace3fb48fd20eae8e2d0cae475f3d5c39d4a688e81ec"},
|
{file = "ipykernel-6.15.3-py3-none-any.whl", hash = "sha256:befe3736944b21afec8e832725e9a45f254c8bd9afc40b61d6661c97e45aff5a"},
|
||||||
{file = "ipykernel-6.15.2.tar.gz", hash = "sha256:e7481083b438609c9c8a22d6362e8e1bc6ec94ba0741b666941e634f2d61bdf3"},
|
{file = "ipykernel-6.15.3.tar.gz", hash = "sha256:b81d57b0e171670844bf29cdc11562b1010d3da87115c4513e0ee660a8368765"},
|
||||||
]
|
]
|
||||||
ipython = [
|
ipython = [
|
||||||
{file = "ipython-8.5.0-py3-none-any.whl", hash = "sha256:6f090e29ab8ef8643e521763a4f1f39dc3914db643122b1e9d3328ff2e43ada2"},
|
{file = "ipython-8.5.0-py3-none-any.whl", hash = "sha256:6f090e29ab8ef8643e521763a4f1f39dc3914db643122b1e9d3328ff2e43ada2"},
|
||||||
|
|
|
@ -23,7 +23,7 @@ wandb = "^0.13.2"
|
||||||
optional = true
|
optional = true
|
||||||
|
|
||||||
[tool.poetry.group.notebooks.dependencies]
|
[tool.poetry.group.notebooks.dependencies]
|
||||||
ipykernel = "^6.15.2"
|
ipykernel = "^6.15.3"
|
||||||
matplotlib = "^3.5.3"
|
matplotlib = "^3.5.3"
|
||||||
onnx = "^1.12.0"
|
onnx = "^1.12.0"
|
||||||
onnxruntime = "^1.12.1"
|
onnxruntime = "^1.12.1"
|
||||||
|
|
|
@ -164,6 +164,7 @@ class LabeledDataset(Dataset):
|
||||||
# create bboxes from masks (pascal format)
|
# create bboxes from masks (pascal format)
|
||||||
num_objs = len(obj_ids)
|
num_objs = len(obj_ids)
|
||||||
bboxes = []
|
bboxes = []
|
||||||
|
labels = []
|
||||||
for i in range(num_objs):
|
for i in range(num_objs):
|
||||||
pos = np.where(masks[i])
|
pos = np.where(masks[i])
|
||||||
xmin = np.min(pos[1])
|
xmin = np.min(pos[1])
|
||||||
|
@ -171,10 +172,11 @@ class LabeledDataset(Dataset):
|
||||||
ymin = np.min(pos[0])
|
ymin = np.min(pos[0])
|
||||||
ymax = np.max(pos[0])
|
ymax = np.max(pos[0])
|
||||||
bboxes.append([xmin, ymin, xmax, ymax])
|
bboxes.append([xmin, ymin, xmax, ymax])
|
||||||
|
labels.append(2 if mask[(ymax + ymin) // 2, (xmax + xmin) // 2] > 127 else 1)
|
||||||
|
|
||||||
# convert arrays for albumentations
|
# convert arrays for albumentations
|
||||||
bboxes = torch.as_tensor(bboxes, dtype=torch.int64)
|
bboxes = torch.as_tensor(bboxes, dtype=torch.int64)
|
||||||
labels = torch.ones((num_objs,), dtype=torch.int64) # assume there is only one class (id=1)
|
labels = torch.as_tensor(labels, dtype=torch.int64)
|
||||||
masks = list(np.asarray(masks))
|
masks = list(np.asarray(masks))
|
||||||
|
|
||||||
if self.transforms is not None:
|
if self.transforms is not None:
|
||||||
|
|
|
@ -35,7 +35,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# Create Network
|
# Create Network
|
||||||
module = MRCNNModule(
|
module = MRCNNModule(
|
||||||
n_classes=2,
|
n_classes=3,
|
||||||
)
|
)
|
||||||
|
|
||||||
# load checkpoint
|
# load checkpoint
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
DIR_TRAIN_IMG:
|
DIR_TRAIN_IMG:
|
||||||
value: "/media/disk1/lfainsin/TRAIN_prerender/"
|
value: "/media/disk1/lfainsin/TRAIN_prerender_old/"
|
||||||
DIR_VALID_IMG:
|
DIR_VALID_IMG:
|
||||||
value: "/media/disk1/lfainsin/TEST_tmp_mrcnn/"
|
value: "/media/disk1/lfainsin/TEST_tmp_mrcnn/"
|
||||||
# DIR_SPHERE:
|
# DIR_SPHERE:
|
||||||
|
|
Loading…
Reference in a new issue