feat: ugly hack to have multiple classes

This commit is contained in:
Laurent Fainsin 2022-09-15 10:40:04 +02:00
parent 57853be03e
commit eb3dabe8d7
5 changed files with 30 additions and 28 deletions

48
poetry.lock generated
View file

@ -24,7 +24,7 @@ multidict = ">=4.5,<7.0"
yarl = ">=1.0,<2.0" yarl = ">=1.0,<2.0"
[package.extras] [package.extras]
speedups = ["aiodns", "brotli", "cchardet"] speedups = ["Brotli", "aiodns", "cchardet"]
[[package]] [[package]]
name = "aiosignal" name = "aiosignal"
@ -97,9 +97,9 @@ optional = false
python-versions = ">=3.5" python-versions = ">=3.5"
[package.extras] [package.extras]
dev = ["cloudpickle", "coverage[toml] (>=5.0.2)", "furo", "hypothesis", "mypy (>=0.900,!=0.940)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "sphinx", "sphinx-notfound-page", "zope-interface"] dev = ["cloudpickle", "coverage[toml] (>=5.0.2)", "furo", "hypothesis", "mypy (>=0.900,!=0.940)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "sphinx", "sphinx-notfound-page", "zope.interface"]
docs = ["furo", "sphinx", "sphinx-notfound-page", "zope-interface"] docs = ["furo", "sphinx", "sphinx-notfound-page", "zope.interface"]
tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "zope-interface"] tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "zope.interface"]
tests_no_zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"] tests_no_zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"]
[[package]] [[package]]
@ -127,7 +127,7 @@ stevedore = ">=1.20.0"
[package.extras] [package.extras]
test = ["beautifulsoup4 (>=4.8.0)", "coverage (>=4.5.4)", "fixtures (>=3.0.0)", "flake8 (>=4.0.0)", "pylint (==1.9.4)", "stestr (>=2.5.0)", "testscenarios (>=0.5.0)", "testtools (>=2.3.0)", "toml"] test = ["beautifulsoup4 (>=4.8.0)", "coverage (>=4.5.4)", "fixtures (>=3.0.0)", "flake8 (>=4.0.0)", "pylint (==1.9.4)", "stestr (>=2.5.0)", "testscenarios (>=0.5.0)", "testtools (>=2.3.0)", "toml"]
toml = ["toml"] toml = ["toml"]
yaml = ["pyyaml"] yaml = ["PyYAML"]
[[package]] [[package]]
name = "black" name = "black"
@ -352,7 +352,7 @@ Flake8 = ">=5,<6"
TOMLi = {version = "*", markers = "python_version < \"3.11\""} TOMLi = {version = "*", markers = "python_version < \"3.11\""}
[package.extras] [package.extras]
test = ["pytest", "pytest-cov"] test = ["pyTest", "pyTest-cov"]
[[package]] [[package]]
name = "flatbuffers" name = "flatbuffers"
@ -544,17 +544,17 @@ pillow = ">=8.3.2"
all-plugins = ["astropy", "av", "imageio-ffmpeg", "opencv-python", "psutil", "tifffile"] all-plugins = ["astropy", "av", "imageio-ffmpeg", "opencv-python", "psutil", "tifffile"]
all-plugins-pypy = ["av", "imageio-ffmpeg", "psutil", "tifffile"] all-plugins-pypy = ["av", "imageio-ffmpeg", "psutil", "tifffile"]
build = ["wheel"] build = ["wheel"]
dev = ["black", "flake8", "fsspec", "invoke", "pytest", "pytest-cov"] dev = ["black", "flake8", "fsspec[github]", "invoke", "pytest", "pytest-cov"]
docs = ["numpydoc", "pydata-sphinx-theme", "sphinx"] docs = ["numpydoc", "pydata-sphinx-theme", "sphinx"]
ffmpeg = ["imageio-ffmpeg", "psutil"] ffmpeg = ["imageio-ffmpeg", "psutil"]
fits = ["astropy"] fits = ["astropy"]
full = ["astropy", "av", "black", "flake8", "fsspec", "gdal", "imageio-ffmpeg", "invoke", "itk", "numpydoc", "opencv-python", "psutil", "pydata-sphinx-theme", "pytest", "pytest-cov", "sphinx", "tifffile", "wheel"] full = ["astropy", "av", "black", "flake8", "fsspec[github]", "gdal", "imageio-ffmpeg", "invoke", "itk", "numpydoc", "opencv-python", "psutil", "pydata-sphinx-theme", "pytest", "pytest-cov", "sphinx", "tifffile", "wheel"]
gdal = ["gdal"] gdal = ["gdal"]
itk = ["itk"] itk = ["itk"]
linting = ["black", "flake8"] linting = ["black", "flake8"]
opencv = ["opencv-python"] opencv = ["opencv-python"]
pyav = ["av"] pyav = ["av"]
test = ["fsspec", "invoke", "pytest", "pytest-cov"] test = ["fsspec[github]", "invoke", "pytest", "pytest-cov"]
tifffile = ["tifffile"] tifffile = ["tifffile"]
[[package]] [[package]]
@ -571,11 +571,11 @@ zipp = ">=0.5"
[package.extras] [package.extras]
docs = ["jaraco.packaging (>=9)", "rst.linker (>=1.9)", "sphinx"] docs = ["jaraco.packaging (>=9)", "rst.linker (>=1.9)", "sphinx"]
perf = ["ipython"] perf = ["ipython"]
testing = ["flufl-flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)"] testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)"]
[[package]] [[package]]
name = "ipykernel" name = "ipykernel"
version = "6.15.2" version = "6.15.3"
description = "IPython Kernel for Jupyter" description = "IPython Kernel for Jupyter"
category = "dev" category = "dev"
optional = false optional = false
@ -728,11 +728,11 @@ torch = ">=1.7.1"
torchmetrics = ">=0.4.1" torchmetrics = ">=0.4.1"
[package.extras] [package.extras]
dev = ["ale-py (>=0.7)", "atari-py (==0.2.6)", "atari-py (>=0.2.0,<0.3.0)", "box2d-py (>=2.3.0,<2.4.0)", "check-manifest", "codecov (>=2.1)", "flake8", "gym[atari] (>=0.17.2,<0.20.0)", "isort (>=5.6.4)", "jsonargparse", "matplotlib", "mypy (>=0.790)", "opencv-python-headless", "pillow", "pre-commit (>=1.0)", "pytest (>=6.0)", "pytest-cov (>2.10)", "scikit-learn (>=0.23)", "scipy", "sparseml", "torchvision (>=0.8.2)", "twine (>=3.2)", "wandb"] dev = ["Pillow", "ale-py (>=0.7)", "atari-py (==0.2.6)", "atari-py (>=0.2.0,<0.3.0)", "box2d-py (>=2.3.0,<2.4.0)", "check-manifest", "codecov (>=2.1)", "flake8", "gym[atari] (>=0.17.2,<0.20.0)", "isort (>=5.6.4)", "jsonargparse[signatures]", "matplotlib", "mypy (>=0.790)", "opencv-python-headless", "pre-commit (>=1.0)", "pytest (>=6.0)", "pytest-cov (>2.10)", "scikit-learn (>=0.23)", "scipy", "sparseml", "torchvision (>=0.8.2)", "twine (>=3.2)", "wandb"]
extra = ["atari-py (>=0.2.0,<0.3.0)", "box2d-py (>=2.3.0,<2.4.0)", "gym[atari] (>=0.17.2,<0.20.0)", "matplotlib", "opencv-python-headless", "pillow", "scikit-learn (>=0.23)", "scipy", "torchvision (>=0.8.2)", "wandb"] extra = ["Pillow", "atari-py (>=0.2.0,<0.3.0)", "box2d-py (>=2.3.0,<2.4.0)", "gym[atari] (>=0.17.2,<0.20.0)", "matplotlib", "opencv-python-headless", "scikit-learn (>=0.23)", "scipy", "torchvision (>=0.8.2)", "wandb"]
loggers = ["matplotlib", "scipy", "wandb"] loggers = ["matplotlib", "scipy", "wandb"]
models = ["atari-py (>=0.2.0,<0.3.0)", "box2d-py (>=2.3.0,<2.4.0)", "gym[atari] (>=0.17.2,<0.20.0)", "opencv-python-headless", "pillow", "scikit-learn (>=0.23)", "torchvision (>=0.8.2)"] models = ["Pillow", "atari-py (>=0.2.0,<0.3.0)", "box2d-py (>=2.3.0,<2.4.0)", "gym[atari] (>=0.17.2,<0.20.0)", "opencv-python-headless", "scikit-learn (>=0.23)", "torchvision (>=0.8.2)"]
test = ["ale-py (>=0.7)", "atari-py (==0.2.6)", "check-manifest", "codecov (>=2.1)", "flake8", "isort (>=5.6.4)", "jsonargparse", "mypy (>=0.790)", "pre-commit (>=1.0)", "pytest (>=6.0)", "pytest-cov (>2.10)", "scikit-learn (>=0.23)", "sparseml", "twine (>=3.2)"] test = ["ale-py (>=0.7)", "atari-py (==0.2.6)", "check-manifest", "codecov (>=2.1)", "flake8", "isort (>=5.6.4)", "jsonargparse[signatures]", "mypy (>=0.790)", "pre-commit (>=1.0)", "pytest (>=6.0)", "pytest-cov (>2.10)", "scikit-learn (>=0.23)", "sparseml", "twine (>=3.2)"]
[[package]] [[package]]
name = "markdown" name = "markdown"
@ -1275,10 +1275,10 @@ tqdm = ">=4.57.0"
typing-extensions = ">=4.0.0" typing-extensions = ">=4.0.0"
[package.extras] [package.extras]
all = ["cloudpickle (>=1.3)", "codecov (>=2.1)", "comet-ml (>=3.1.12)", "coverage (>=6.4)", "deepspeed (>=0.6.0)", "fairscale (>=0.4.5)", "fastapi", "gcsfs (>=2021.5.0)", "gym[classic_control] (>=0.17.0)", "hivemind (>=1.0.1)", "horovod (>=0.21.2,!=0.24.0)", "hydra-core (>=1.0.5)", "ipython", "jsonargparse[signatures] (>=4.12.0)", "matplotlib (>3.1)", "mlflow (>=1.0.0)", "mypy (==0.971)", "neptune-client (>=0.10.0)", "omegaconf (>=2.0.5)", "onnxruntime", "pandas (>1.0)", "pre-commit (>=1.0)", "protobuf (<=3.20.1)", "psutil", "pytest (>=7.0)", "pytest-cov", "pytest-forked", "pytest-rerunfailures (>=10.2)", "rich (>=10.14.0,!=10.15.0.a)", "scikit-learn (>0.22.1)", "torchtext (>=0.10)", "torchvision (>=0.10)", "uvicorn", "wandb (>=0.10.22)"] all = ["cloudpickle (>=1.3)", "codecov (>=2.1)", "comet-ml (>=3.1.12)", "coverage (>=6.4)", "deepspeed (>=0.6.0)", "fairscale (>=0.4.5)", "fastapi", "gcsfs (>=2021.5.0)", "gym[classic_control] (>=0.17.0)", "hivemind (>=1.0.1)", "horovod (>=0.21.2,!=0.24.0)", "hydra-core (>=1.0.5)", "ipython[all]", "jsonargparse[signatures] (>=4.12.0)", "matplotlib (>3.1)", "mlflow (>=1.0.0)", "mypy (==0.971)", "neptune-client (>=0.10.0)", "omegaconf (>=2.0.5)", "onnxruntime", "pandas (>1.0)", "pre-commit (>=1.0)", "protobuf (<=3.20.1)", "psutil", "pytest (>=7.0)", "pytest-cov", "pytest-forked", "pytest-rerunfailures (>=10.2)", "rich (>=10.14.0,!=10.15.0.a)", "scikit-learn (>0.22.1)", "torchtext (>=0.10)", "torchvision (>=0.10)", "uvicorn", "wandb (>=0.10.22)"]
deepspeed = ["deepspeed (>=0.6.0)"] deepspeed = ["deepspeed (>=0.6.0)"]
dev = ["cloudpickle (>=1.3)", "codecov (>=2.1)", "comet-ml (>=3.1.12)", "coverage (>=6.4)", "fastapi", "gcsfs (>=2021.5.0)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.12.0)", "matplotlib (>3.1)", "mlflow (>=1.0.0)", "mypy (==0.971)", "neptune-client (>=0.10.0)", "omegaconf (>=2.0.5)", "onnxruntime", "pandas (>1.0)", "pre-commit (>=1.0)", "protobuf (<=3.20.1)", "psutil", "pytest (>=7.0)", "pytest-cov", "pytest-forked", "pytest-rerunfailures (>=10.2)", "rich (>=10.14.0,!=10.15.0.a)", "scikit-learn (>0.22.1)", "torchtext (>=0.10)", "uvicorn", "wandb (>=0.10.22)"] dev = ["cloudpickle (>=1.3)", "codecov (>=2.1)", "comet-ml (>=3.1.12)", "coverage (>=6.4)", "fastapi", "gcsfs (>=2021.5.0)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.12.0)", "matplotlib (>3.1)", "mlflow (>=1.0.0)", "mypy (==0.971)", "neptune-client (>=0.10.0)", "omegaconf (>=2.0.5)", "onnxruntime", "pandas (>1.0)", "pre-commit (>=1.0)", "protobuf (<=3.20.1)", "psutil", "pytest (>=7.0)", "pytest-cov", "pytest-forked", "pytest-rerunfailures (>=10.2)", "rich (>=10.14.0,!=10.15.0.a)", "scikit-learn (>0.22.1)", "torchtext (>=0.10)", "uvicorn", "wandb (>=0.10.22)"]
examples = ["gym[classic_control] (>=0.17.0)", "ipython", "torchvision (>=0.10)"] examples = ["gym[classic_control] (>=0.17.0)", "ipython[all]", "torchvision (>=0.10)"]
extra = ["gcsfs (>=2021.5.0)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.12.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "protobuf (<=3.20.1)", "rich (>=10.14.0,!=10.15.0.a)", "torchtext (>=0.10)"] extra = ["gcsfs (>=2021.5.0)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.12.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "protobuf (<=3.20.1)", "rich (>=10.14.0,!=10.15.0.a)", "torchtext (>=0.10)"]
fairscale = ["fairscale (>=0.4.5)"] fairscale = ["fairscale (>=0.4.5)"]
hivemind = ["hivemind (>=1.0.1)"] hivemind = ["hivemind (>=1.0.1)"]
@ -1421,7 +1421,7 @@ tifffile = ">=2019.7.26"
[package.extras] [package.extras]
data = ["pooch (>=1.3.0)"] data = ["pooch (>=1.3.0)"]
docs = ["cloudpickle (>=0.2.1)", "dask[array] (>=0.15.0,!=2.17.0)", "ipywidgets", "kaleido", "matplotlib (>=3.3)", "myst-parser", "numpydoc (>=1.0)", "pandas (>=0.23.0)", "plotly (>=4.14.0)", "pooch (>=1.3.0)", "pytest-runner", "scikit-learn", "seaborn (>=0.7.1)", "sphinx (>=1.8)", "sphinx-copybutton", "sphinx-gallery (>=0.10.1)", "tifffile (>=2020.5.30)"] docs = ["cloudpickle (>=0.2.1)", "dask[array] (>=0.15.0,!=2.17.0)", "ipywidgets", "kaleido", "matplotlib (>=3.3)", "myst-parser", "numpydoc (>=1.0)", "pandas (>=0.23.0)", "plotly (>=4.14.0)", "pooch (>=1.3.0)", "pytest-runner", "scikit-learn", "seaborn (>=0.7.1)", "sphinx (>=1.8)", "sphinx-copybutton", "sphinx-gallery (>=0.10.1)", "tifffile (>=2020.5.30)"]
optional = ["astropy (>=3.1.2)", "cloudpickle (>=0.2.1)", "dask[array] (>=1.0.0,!=2.17.0)", "matplotlib (>=3.0.3)", "pooch (>=1.3.0)", "pyamg", "qtpy", "simpleitk"] optional = ["SimpleITK", "astropy (>=3.1.2)", "cloudpickle (>=0.2.1)", "dask[array] (>=1.0.0,!=2.17.0)", "matplotlib (>=3.0.3)", "pooch (>=1.3.0)", "pyamg", "qtpy"]
test = ["asv", "codecov", "flake8", "matplotlib (>=3.0.3)", "pooch (>=1.3.0)", "pytest (>=5.2.0)", "pytest-cov (>=2.7.0)", "pytest-faulthandler", "pytest-localserver"] test = ["asv", "codecov", "flake8", "matplotlib (>=3.0.3)", "pooch (>=1.3.0)", "pytest (>=5.2.0)", "pytest-cov (>=2.7.0)", "pytest-faulthandler", "pytest-localserver"]
[[package]] [[package]]
@ -1508,8 +1508,8 @@ python-versions = ">=3.7"
[package.extras] [package.extras]
docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"] docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
testing = ["build", "filelock (>=3.4.0)", "flake8 (<5)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mock", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8 (<5)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mock", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
testing-integration = ["build", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
[[package]] [[package]]
name = "setuptools-scm" name = "setuptools-scm"
@ -1899,12 +1899,12 @@ python-versions = ">=3.7"
[package.extras] [package.extras]
docs = ["jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx"] docs = ["jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx"]
testing = ["func-timeout", "jaraco-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] testing = ["func-timeout", "jaraco.itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = ">=3.8,<3.11" python-versions = ">=3.8,<3.11"
content-hash = "8eba042ca7188b409be6ea812f315b9aa6964791195e21ea8ea66655bb4c4fab" content-hash = "5f210ee2c6b72cc8e0021f082cc9c7a985528a82c437dca40086704faed8d2f6"
[metadata.files] [metadata.files]
absl-py = [ absl-py = [
@ -2357,8 +2357,8 @@ importlib-metadata = [
{file = "importlib_metadata-4.12.0.tar.gz", hash = "sha256:637245b8bab2b6502fcbc752cc4b7a6f6243bb02b31c5c26156ad103d3d45670"}, {file = "importlib_metadata-4.12.0.tar.gz", hash = "sha256:637245b8bab2b6502fcbc752cc4b7a6f6243bb02b31c5c26156ad103d3d45670"},
] ]
ipykernel = [ ipykernel = [
{file = "ipykernel-6.15.2-py3-none-any.whl", hash = "sha256:59183ef833b82c72211aace3fb48fd20eae8e2d0cae475f3d5c39d4a688e81ec"}, {file = "ipykernel-6.15.3-py3-none-any.whl", hash = "sha256:befe3736944b21afec8e832725e9a45f254c8bd9afc40b61d6661c97e45aff5a"},
{file = "ipykernel-6.15.2.tar.gz", hash = "sha256:e7481083b438609c9c8a22d6362e8e1bc6ec94ba0741b666941e634f2d61bdf3"}, {file = "ipykernel-6.15.3.tar.gz", hash = "sha256:b81d57b0e171670844bf29cdc11562b1010d3da87115c4513e0ee660a8368765"},
] ]
ipython = [ ipython = [
{file = "ipython-8.5.0-py3-none-any.whl", hash = "sha256:6f090e29ab8ef8643e521763a4f1f39dc3914db643122b1e9d3328ff2e43ada2"}, {file = "ipython-8.5.0-py3-none-any.whl", hash = "sha256:6f090e29ab8ef8643e521763a4f1f39dc3914db643122b1e9d3328ff2e43ada2"},

View file

@ -23,7 +23,7 @@ wandb = "^0.13.2"
optional = true optional = true
[tool.poetry.group.notebooks.dependencies] [tool.poetry.group.notebooks.dependencies]
ipykernel = "^6.15.2" ipykernel = "^6.15.3"
matplotlib = "^3.5.3" matplotlib = "^3.5.3"
onnx = "^1.12.0" onnx = "^1.12.0"
onnxruntime = "^1.12.1" onnxruntime = "^1.12.1"

View file

@ -164,6 +164,7 @@ class LabeledDataset(Dataset):
# create bboxes from masks (pascal format) # create bboxes from masks (pascal format)
num_objs = len(obj_ids) num_objs = len(obj_ids)
bboxes = [] bboxes = []
labels = []
for i in range(num_objs): for i in range(num_objs):
pos = np.where(masks[i]) pos = np.where(masks[i])
xmin = np.min(pos[1]) xmin = np.min(pos[1])
@ -171,10 +172,11 @@ class LabeledDataset(Dataset):
ymin = np.min(pos[0]) ymin = np.min(pos[0])
ymax = np.max(pos[0]) ymax = np.max(pos[0])
bboxes.append([xmin, ymin, xmax, ymax]) bboxes.append([xmin, ymin, xmax, ymax])
labels.append(2 if mask[(ymax + ymin) // 2, (xmax + xmin) // 2] > 127 else 1)
# convert arrays for albumentations # convert arrays for albumentations
bboxes = torch.as_tensor(bboxes, dtype=torch.int64) bboxes = torch.as_tensor(bboxes, dtype=torch.int64)
labels = torch.ones((num_objs,), dtype=torch.int64) # assume there is only one class (id=1) labels = torch.as_tensor(labels, dtype=torch.int64)
masks = list(np.asarray(masks)) masks = list(np.asarray(masks))
if self.transforms is not None: if self.transforms is not None:

View file

@ -35,7 +35,7 @@ if __name__ == "__main__":
# Create Network # Create Network
module = MRCNNModule( module = MRCNNModule(
n_classes=2, n_classes=3,
) )
# load checkpoint # load checkpoint

View file

@ -1,5 +1,5 @@
DIR_TRAIN_IMG: DIR_TRAIN_IMG:
value: "/media/disk1/lfainsin/TRAIN_prerender/" value: "/media/disk1/lfainsin/TRAIN_prerender_old/"
DIR_VALID_IMG: DIR_VALID_IMG:
value: "/media/disk1/lfainsin/TEST_tmp_mrcnn/" value: "/media/disk1/lfainsin/TEST_tmp_mrcnn/"
# DIR_SPHERE: # DIR_SPHERE: