mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
feat: checkpointing + torchmetrics + ...
Former-commit-id: 781a11f646964567b3bc6831caa5e380748b84e1 [formerly 7ea5f37ecb7fd5c95e5d1b7b1aa899d986ecff2a] Former-commit-id: 06974fa163a6a8eb881d5981f9c6debe63f1b4bb
This commit is contained in:
parent
c0a52196fd
commit
c312513eff
|
@ -11,7 +11,7 @@ repos:
|
||||||
- id: check-executables-have-shebangs
|
- id: check-executables-have-shebangs
|
||||||
- id: check-merge-conflict
|
- id: check-merge-conflict
|
||||||
- id: check-symlinks
|
- id: check-symlinks
|
||||||
- id: check-json
|
# - id: check-json
|
||||||
- id: check-toml
|
- id: check-toml
|
||||||
- id: check-yaml
|
- id: check-yaml
|
||||||
- id: debug-statements
|
- id: debug-statements
|
||||||
|
|
22
.vscode/launch.json
vendored
22
.vscode/launch.json
vendored
|
@ -2,20 +2,28 @@
|
||||||
"version": "0.2.0",
|
"version": "0.2.0",
|
||||||
"configurations": [
|
"configurations": [
|
||||||
{
|
{
|
||||||
"name": "Python: Current File",
|
"name": "Train",
|
||||||
"type": "python",
|
"type": "python",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "${file}",
|
"program": "${workspaceFolder}/src/train.py",
|
||||||
"console": "integratedTerminal",
|
"console": "integratedTerminal",
|
||||||
|
"justMyCode": false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Predict",
|
||||||
|
"type": "python",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "${workspaceFolder}/src/predict.py",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"justMyCode": false,
|
||||||
"args": [
|
"args": [
|
||||||
"--input",
|
"--input",
|
||||||
"images/test.png",
|
"images/input.png",
|
||||||
"--output",
|
"--output",
|
||||||
"output_onnx.png",
|
"images/output.png",
|
||||||
"--model",
|
"--model",
|
||||||
"good.onnx"
|
"checkpoints/model.onnx"
|
||||||
],
|
]
|
||||||
"justMyCode": false
|
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
1
.vscode/settings.json
vendored
1
.vscode/settings.json
vendored
|
@ -20,5 +20,6 @@
|
||||||
"**/.DS_Store": true,
|
"**/.DS_Store": true,
|
||||||
"**/Thumbs.db": true,
|
"**/Thumbs.db": true,
|
||||||
"**/__pycache__": true,
|
"**/__pycache__": true,
|
||||||
|
"**/.mypy_cache": true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
106
poetry.lock
generated
106
poetry.lock
generated
|
@ -203,7 +203,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
||||||
name = "coloredlogs"
|
name = "coloredlogs"
|
||||||
version = "15.0.1"
|
version = "15.0.1"
|
||||||
description = "Colored terminal output for Python's logging module"
|
description = "Colored terminal output for Python's logging module"
|
||||||
category = "dev"
|
category = "main"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
||||||
|
|
||||||
|
@ -213,11 +213,22 @@ humanfriendly = ">=9.1"
|
||||||
[package.extras]
|
[package.extras]
|
||||||
cron = ["capturer (>=2.4)"]
|
cron = ["capturer (>=2.4)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "commonmark"
|
||||||
|
version = "0.9.1"
|
||||||
|
description = "Python parser for the CommonMark Markdown spec"
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = "*"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
test = ["flake8 (==3.7.8)", "hypothesis (==3.55.3)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cycler"
|
name = "cycler"
|
||||||
version = "0.11.0"
|
version = "0.11.0"
|
||||||
description = "Composable style cycles"
|
description = "Composable style cycles"
|
||||||
category = "dev"
|
category = "main"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.6"
|
python-versions = ">=3.6"
|
||||||
|
|
||||||
|
@ -288,7 +299,7 @@ testing = ["covdefaults (>=2.2)", "coverage (>=6.4.2)", "pytest (>=7.1.2)", "pyt
|
||||||
name = "flatbuffers"
|
name = "flatbuffers"
|
||||||
version = "2.0.7"
|
version = "2.0.7"
|
||||||
description = "The FlatBuffers serialization format for Python"
|
description = "The FlatBuffers serialization format for Python"
|
||||||
category = "dev"
|
category = "main"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = "*"
|
python-versions = "*"
|
||||||
|
|
||||||
|
@ -296,7 +307,7 @@ python-versions = "*"
|
||||||
name = "fonttools"
|
name = "fonttools"
|
||||||
version = "4.37.1"
|
version = "4.37.1"
|
||||||
description = "Tools to manipulate font files"
|
description = "Tools to manipulate font files"
|
||||||
category = "dev"
|
category = "main"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
|
|
||||||
|
@ -432,7 +443,7 @@ protobuf = ["grpcio-tools (>=1.48.1)"]
|
||||||
name = "humanfriendly"
|
name = "humanfriendly"
|
||||||
version = "10.0"
|
version = "10.0"
|
||||||
description = "Human friendly output for text interfaces using Python"
|
description = "Human friendly output for text interfaces using Python"
|
||||||
category = "dev"
|
category = "main"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
||||||
|
|
||||||
|
@ -639,7 +650,7 @@ test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"]
|
||||||
name = "kiwisolver"
|
name = "kiwisolver"
|
||||||
version = "1.4.4"
|
version = "1.4.4"
|
||||||
description = "A fast implementation of the Cassowary constraint solver"
|
description = "A fast implementation of the Cassowary constraint solver"
|
||||||
category = "dev"
|
category = "main"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
|
|
||||||
|
@ -669,7 +680,7 @@ python-versions = ">=3.7"
|
||||||
name = "matplotlib"
|
name = "matplotlib"
|
||||||
version = "3.5.3"
|
version = "3.5.3"
|
||||||
description = "Python plotting package"
|
description = "Python plotting package"
|
||||||
category = "dev"
|
category = "main"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
|
|
||||||
|
@ -699,7 +710,7 @@ traitlets = "*"
|
||||||
name = "mpmath"
|
name = "mpmath"
|
||||||
version = "1.2.1"
|
version = "1.2.1"
|
||||||
description = "Python library for arbitrary-precision floating-point arithmetic"
|
description = "Python library for arbitrary-precision floating-point arithmetic"
|
||||||
category = "dev"
|
category = "main"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = "*"
|
python-versions = "*"
|
||||||
|
|
||||||
|
@ -825,6 +836,22 @@ packaging = "*"
|
||||||
protobuf = "*"
|
protobuf = "*"
|
||||||
sympy = "*"
|
sympy = "*"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "onnxruntime-gpu"
|
||||||
|
version = "1.12.1"
|
||||||
|
description = "ONNX Runtime is a runtime accelerator for Machine Learning models"
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = "*"
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
coloredlogs = "*"
|
||||||
|
flatbuffers = "*"
|
||||||
|
numpy = ">=1.21.0"
|
||||||
|
packaging = "*"
|
||||||
|
protobuf = "*"
|
||||||
|
sympy = "*"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "opencv-python-headless"
|
name = "opencv-python-headless"
|
||||||
version = "4.6.0.66"
|
version = "4.6.0.66"
|
||||||
|
@ -1029,6 +1056,18 @@ python-versions = "*"
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
pyasn1 = ">=0.4.6,<0.5.0"
|
pyasn1 = ">=0.4.6,<0.5.0"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pycocotools"
|
||||||
|
version = "2.0.4"
|
||||||
|
description = "Official APIs for the MS-COCO dataset"
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.5"
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
matplotlib = ">=2.1.0"
|
||||||
|
numpy = "*"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pycparser"
|
name = "pycparser"
|
||||||
version = "2.21"
|
version = "2.21"
|
||||||
|
@ -1049,7 +1088,7 @@ python-versions = ">=3.6"
|
||||||
name = "pygments"
|
name = "pygments"
|
||||||
version = "2.13.0"
|
version = "2.13.0"
|
||||||
description = "Pygments is a syntax highlighting package written in Python."
|
description = "Pygments is a syntax highlighting package written in Python."
|
||||||
category = "dev"
|
category = "main"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.6"
|
python-versions = ">=3.6"
|
||||||
|
|
||||||
|
@ -1071,7 +1110,7 @@ diagrams = ["jinja2", "railroad-diagrams"]
|
||||||
name = "pyreadline3"
|
name = "pyreadline3"
|
||||||
version = "3.4.1"
|
version = "3.4.1"
|
||||||
description = "A python implementation of GNU readline."
|
description = "A python implementation of GNU readline."
|
||||||
category = "dev"
|
category = "main"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = "*"
|
python-versions = "*"
|
||||||
|
|
||||||
|
@ -1079,7 +1118,7 @@ python-versions = "*"
|
||||||
name = "python-dateutil"
|
name = "python-dateutil"
|
||||||
version = "2.8.2"
|
version = "2.8.2"
|
||||||
description = "Extensions to the standard Python datetime module"
|
description = "Extensions to the standard Python datetime module"
|
||||||
category = "dev"
|
category = "main"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
|
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
|
||||||
|
|
||||||
|
@ -1205,6 +1244,22 @@ requests = ">=2.0.0"
|
||||||
[package.extras]
|
[package.extras]
|
||||||
rsa = ["oauthlib[signedtoken] (>=3.0.0)"]
|
rsa = ["oauthlib[signedtoken] (>=3.0.0)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rich"
|
||||||
|
version = "12.5.1"
|
||||||
|
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.6.3,<4.0.0"
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
commonmark = ">=0.9.0,<0.10.0"
|
||||||
|
pygments = ">=2.6.0,<3.0.0"
|
||||||
|
typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9\""}
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rsa"
|
name = "rsa"
|
||||||
version = "4.9"
|
version = "4.9"
|
||||||
|
@ -1318,7 +1373,7 @@ test = ["pytest"]
|
||||||
name = "setuptools-scm"
|
name = "setuptools-scm"
|
||||||
version = "6.4.2"
|
version = "6.4.2"
|
||||||
description = "the blessed package to manage your versions by scm tags"
|
description = "the blessed package to manage your versions by scm tags"
|
||||||
category = "dev"
|
category = "main"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.6"
|
python-versions = ">=3.6"
|
||||||
|
|
||||||
|
@ -1374,7 +1429,7 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"]
|
||||||
name = "sympy"
|
name = "sympy"
|
||||||
version = "1.11.1"
|
version = "1.11.1"
|
||||||
description = "Computer algebra system (CAS) in Python"
|
description = "Computer algebra system (CAS) in Python"
|
||||||
category = "dev"
|
category = "main"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8"
|
python-versions = ">=3.8"
|
||||||
|
|
||||||
|
@ -1460,7 +1515,7 @@ python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
|
||||||
name = "tomli"
|
name = "tomli"
|
||||||
version = "2.0.1"
|
version = "2.0.1"
|
||||||
description = "A lil' TOML parser"
|
description = "A lil' TOML parser"
|
||||||
category = "dev"
|
category = "main"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
|
|
||||||
|
@ -1673,7 +1728,7 @@ testing = ["func-timeout", "jaraco.itertools", "pytest (>=6)", "pytest-black (>=
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "1.1"
|
lock-version = "1.1"
|
||||||
python-versions = ">=3.8,<3.11"
|
python-versions = ">=3.8,<3.11"
|
||||||
content-hash = "41f098cb3b123362c4c88942260ee0fd07c8b175d0f7f4481ab9a6e4dc401d7f"
|
content-hash = "a937d2dadd3250c71c1d8a43a3a4a72ac53e6f3793d856c17a4f4f199de71d0b"
|
||||||
|
|
||||||
[metadata.files]
|
[metadata.files]
|
||||||
absl-py = [
|
absl-py = [
|
||||||
|
@ -1901,6 +1956,10 @@ coloredlogs = [
|
||||||
{file = "coloredlogs-15.0.1-py2.py3-none-any.whl", hash = "sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934"},
|
{file = "coloredlogs-15.0.1-py2.py3-none-any.whl", hash = "sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934"},
|
||||||
{file = "coloredlogs-15.0.1.tar.gz", hash = "sha256:7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0"},
|
{file = "coloredlogs-15.0.1.tar.gz", hash = "sha256:7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0"},
|
||||||
]
|
]
|
||||||
|
commonmark = [
|
||||||
|
{file = "commonmark-0.9.1-py2.py3-none-any.whl", hash = "sha256:da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9"},
|
||||||
|
{file = "commonmark-0.9.1.tar.gz", hash = "sha256:452f9dc859be7f06631ddcb328b6919c67984aca654e5fefb3914d54691aed60"},
|
||||||
|
]
|
||||||
cycler = [
|
cycler = [
|
||||||
{file = "cycler-0.11.0-py3-none-any.whl", hash = "sha256:3a27e95f763a428a739d2add979fa7494c912a32c17c4c38c4d5f082cad165a3"},
|
{file = "cycler-0.11.0-py3-none-any.whl", hash = "sha256:3a27e95f763a428a739d2add979fa7494c912a32c17c4c38c4d5f082cad165a3"},
|
||||||
{file = "cycler-0.11.0.tar.gz", hash = "sha256:9c87405839a19696e837b3b818fed3f5f69f16f1eec1a1ad77e043dcea9c772f"},
|
{file = "cycler-0.11.0.tar.gz", hash = "sha256:9c87405839a19696e837b3b818fed3f5f69f16f1eec1a1ad77e043dcea9c772f"},
|
||||||
|
@ -2483,6 +2542,16 @@ onnxruntime = [
|
||||||
{file = "onnxruntime-1.12.1-cp39-cp39-win32.whl", hash = "sha256:a9954f6ffab4a0a3877a4800d817950a236a6db4901399eec1ea52033f52da94"},
|
{file = "onnxruntime-1.12.1-cp39-cp39-win32.whl", hash = "sha256:a9954f6ffab4a0a3877a4800d817950a236a6db4901399eec1ea52033f52da94"},
|
||||||
{file = "onnxruntime-1.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:76bbd92cbcc5b6b0f893565f072e33f921ae3350a77b74fb7c65757e683516c7"},
|
{file = "onnxruntime-1.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:76bbd92cbcc5b6b0f893565f072e33f921ae3350a77b74fb7c65757e683516c7"},
|
||||||
]
|
]
|
||||||
|
onnxruntime-gpu = [
|
||||||
|
{file = "onnxruntime_gpu-1.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42b0393c5122ed90fa2eb76192a486261d86e9526ccb78b2a98923c22791d2d1"},
|
||||||
|
{file = "onnxruntime_gpu-1.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:ecfe97335027e569d4f46725ba89316041e562b8c499690e25e11cfee4601cd1"},
|
||||||
|
{file = "onnxruntime_gpu-1.12.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2be6f7f5a1ce0bc8471ce42e10eab92cfb19d0748b857edcb5320b5e98311b7"},
|
||||||
|
{file = "onnxruntime_gpu-1.12.1-cp37-cp37m-win_amd64.whl", hash = "sha256:d73204323aefebe4eecab9fcf76e22b1a00394e3d838c2962a28a27301186b73"},
|
||||||
|
{file = "onnxruntime_gpu-1.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b37872527d03d3df10756408ca44014bd6ac354a044ab1c4286cd42dc138e518"},
|
||||||
|
{file = "onnxruntime_gpu-1.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:296bd9733986cb7517d15bef5535c555d3f863963a71e6575e92d2a854aee61d"},
|
||||||
|
{file = "onnxruntime_gpu-1.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e46d0724ce54c5908c5760037b78de741fbd48962b370c29ebc20e608b30eda"},
|
||||||
|
{file = "onnxruntime_gpu-1.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:fd919373be35b9ba54210688265df38ad5e19a530449385c40dab51da407345d"},
|
||||||
|
]
|
||||||
opencv-python-headless = [
|
opencv-python-headless = [
|
||||||
{file = "opencv-python-headless-4.6.0.66.tar.gz", hash = "sha256:d5291d7e10aa2c19cab6fd86f0d61af8617290ecd2d7ffcb051e446868d04cc5"},
|
{file = "opencv-python-headless-4.6.0.66.tar.gz", hash = "sha256:d5291d7e10aa2c19cab6fd86f0d61af8617290ecd2d7ffcb051e446868d04cc5"},
|
||||||
{file = "opencv_python_headless-4.6.0.66-cp36-abi3-macosx_10_15_x86_64.whl", hash = "sha256:21e70f8b0c04098cdf466d27184fe6c3820aaef944a22548db95099959c95889"},
|
{file = "opencv_python_headless-4.6.0.66-cp36-abi3-macosx_10_15_x86_64.whl", hash = "sha256:21e70f8b0c04098cdf466d27184fe6c3820aaef944a22548db95099959c95889"},
|
||||||
|
@ -2694,6 +2763,9 @@ pyasn1-modules = [
|
||||||
{file = "pyasn1_modules-0.2.8-py3.6.egg", hash = "sha256:cbac4bc38d117f2a49aeedec4407d23e8866ea4ac27ff2cf7fb3e5b570df19e0"},
|
{file = "pyasn1_modules-0.2.8-py3.6.egg", hash = "sha256:cbac4bc38d117f2a49aeedec4407d23e8866ea4ac27ff2cf7fb3e5b570df19e0"},
|
||||||
{file = "pyasn1_modules-0.2.8-py3.7.egg", hash = "sha256:c29a5e5cc7a3f05926aff34e097e84f8589cd790ce0ed41b67aed6857b26aafd"},
|
{file = "pyasn1_modules-0.2.8-py3.7.egg", hash = "sha256:c29a5e5cc7a3f05926aff34e097e84f8589cd790ce0ed41b67aed6857b26aafd"},
|
||||||
]
|
]
|
||||||
|
pycocotools = [
|
||||||
|
{file = "pycocotools-2.0.4.tar.gz", hash = "sha256:2ab586aa389b9657b6d73c2b9a827a3681f8d00f36490c2e8ab05902e3fd9e93"},
|
||||||
|
]
|
||||||
pycparser = [
|
pycparser = [
|
||||||
{file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"},
|
{file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"},
|
||||||
{file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"},
|
{file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"},
|
||||||
|
@ -2892,6 +2964,10 @@ requests-oauthlib = [
|
||||||
{file = "requests-oauthlib-1.3.1.tar.gz", hash = "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a"},
|
{file = "requests-oauthlib-1.3.1.tar.gz", hash = "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a"},
|
||||||
{file = "requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5"},
|
{file = "requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5"},
|
||||||
]
|
]
|
||||||
|
rich = [
|
||||||
|
{file = "rich-12.5.1-py3-none-any.whl", hash = "sha256:2eb4e6894cde1e017976d2975ac210ef515d7548bc595ba20e195fb9628acdeb"},
|
||||||
|
{file = "rich-12.5.1.tar.gz", hash = "sha256:63a5c5ce3673d3d5fbbf23cd87e11ab84b6b451436f1b7f19ec54b6bc36ed7ca"},
|
||||||
|
]
|
||||||
rsa = [
|
rsa = [
|
||||||
{file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"},
|
{file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"},
|
||||||
{file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"},
|
{file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"},
|
||||||
|
|
|
@ -14,6 +14,8 @@ torch = "^1.12.1"
|
||||||
torchmetrics = "^0.9.3"
|
torchmetrics = "^0.9.3"
|
||||||
torchvision = "^0.13.1"
|
torchvision = "^0.13.1"
|
||||||
wandb = "^0.13.2"
|
wandb = "^0.13.2"
|
||||||
|
rich = "^12.5.1"
|
||||||
|
pycocotools = "^2.0.4"
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
black = { extras = ["jupyter"], version = "^22.8.0" }
|
black = { extras = ["jupyter"], version = "^22.8.0" }
|
||||||
|
@ -22,6 +24,7 @@ isort = "^5.10.1"
|
||||||
matplotlib = "^3.5.3"
|
matplotlib = "^3.5.3"
|
||||||
mypy = "^0.971"
|
mypy = "^0.971"
|
||||||
onnxruntime = "^1.12.1"
|
onnxruntime = "^1.12.1"
|
||||||
|
onnxruntime-gpu = "^1.12.1"
|
||||||
pre-commit = "^2.20.0"
|
pre-commit = "^2.20.0"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
import albumentations as A
|
import albumentations as A
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
from albumentations.pytorch import ToTensorV2
|
|
||||||
from torch.utils.data import DataLoader, Subset
|
|
||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
|
from albumentations.pytorch import ToTensorV2
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from .dataset import RealDataset
|
from .dataset import RealDataset
|
||||||
|
|
||||||
|
@ -20,17 +19,18 @@ class Spheres(pl.LightningDataModule):
|
||||||
transforms = A.Compose(
|
transforms = A.Compose(
|
||||||
[
|
[
|
||||||
A.Flip(),
|
A.Flip(),
|
||||||
A.ColorJitter(),
|
# A.ColorJitter(),
|
||||||
A.ToGray(p=0.01),
|
# A.ToGray(p=0.01),
|
||||||
A.GaussianBlur(),
|
# A.GaussianBlur(),
|
||||||
A.MotionBlur(),
|
# A.MotionBlur(),
|
||||||
A.ISONoise(),
|
# A.ISONoise(),
|
||||||
A.ImageCompression(),
|
# A.ImageCompression(),
|
||||||
A.Normalize(
|
# A.Normalize(
|
||||||
mean=[0.485, 0.456, 0.406],
|
# mean=[0.485, 0.456, 0.406],
|
||||||
std=[0.229, 0.224, 0.225],
|
# std=[0.229, 0.224, 0.225],
|
||||||
max_pixel_value=255,
|
# max_pixel_value=255,
|
||||||
), # [0, 255] -> coco (?) normalized
|
# ), # [0, 255] -> coco (?) normalized
|
||||||
|
A.ToFloat(max_value=255),
|
||||||
ToTensorV2(), # HWC -> CHW
|
ToTensorV2(), # HWC -> CHW
|
||||||
],
|
],
|
||||||
bbox_params=A.BboxParams(
|
bbox_params=A.BboxParams(
|
||||||
|
@ -57,11 +57,12 @@ class Spheres(pl.LightningDataModule):
|
||||||
def val_dataloader(self):
|
def val_dataloader(self):
|
||||||
transforms = A.Compose(
|
transforms = A.Compose(
|
||||||
[
|
[
|
||||||
A.Normalize(
|
# A.Normalize(
|
||||||
mean=[0.485, 0.456, 0.406],
|
# mean=[0.485, 0.456, 0.406],
|
||||||
std=[0.229, 0.224, 0.225],
|
# std=[0.229, 0.224, 0.225],
|
||||||
max_pixel_value=255,
|
# max_pixel_value=255,
|
||||||
), # [0, 255] -> [0.0, 1.0] normalized
|
# ), # [0, 255] -> [0.0, 1.0] normalized
|
||||||
|
A.ToFloat(max_value=255),
|
||||||
ToTensorV2(), # HWC -> CHW
|
ToTensorV2(), # HWC -> CHW
|
||||||
],
|
],
|
||||||
bbox_params=A.BboxParams(
|
bbox_params=A.BboxParams(
|
||||||
|
|
|
@ -3,16 +3,14 @@
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
|
import wandb
|
||||||
|
from torchmetrics.detection.mean_ap import MeanAveragePrecision
|
||||||
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
||||||
from torchvision.models.detection.mask_rcnn import (
|
from torchvision.models.detection.mask_rcnn import (
|
||||||
MaskRCNN_ResNet50_FPN_Weights,
|
MaskRCNN_ResNet50_FPN_Weights,
|
||||||
MaskRCNNPredictor,
|
MaskRCNNPredictor,
|
||||||
)
|
)
|
||||||
|
|
||||||
import wandb
|
|
||||||
from utils.coco_eval import CocoEvaluator
|
|
||||||
from utils.coco_utils import get_coco_api_from_dataset, get_iou_types
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_instance_segmentation(num_classes):
|
def get_model_instance_segmentation(num_classes):
|
||||||
# load an instance segmentation model pre-trained on COCO
|
# load an instance segmentation model pre-trained on COCO
|
||||||
|
@ -33,11 +31,10 @@ def get_model_instance_segmentation(num_classes):
|
||||||
|
|
||||||
|
|
||||||
class MRCNNModule(pl.LightningModule):
|
class MRCNNModule(pl.LightningModule):
|
||||||
def __init__(self, hidden_layer_size, n_classes):
|
def __init__(self, n_classes):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
self.hidden_layers_size = hidden_layer_size
|
|
||||||
self.n_classes = n_classes
|
self.n_classes = n_classes
|
||||||
|
|
||||||
# log hyperparameters
|
# log hyperparameters
|
||||||
|
@ -46,10 +43,12 @@ class MRCNNModule(pl.LightningModule):
|
||||||
# Network
|
# Network
|
||||||
self.model = get_model_instance_segmentation(n_classes)
|
self.model = get_model_instance_segmentation(n_classes)
|
||||||
|
|
||||||
# pycoco evaluator
|
# onnx
|
||||||
self.coco = None
|
self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True)
|
||||||
self.iou_types = get_iou_types(self.model)
|
|
||||||
self.coco_evaluator = None
|
def forward(self, imgs):
|
||||||
|
self.model.eval()
|
||||||
|
return self.model(imgs)
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx):
|
def training_step(self, batch, batch_idx):
|
||||||
# unpack batch
|
# unpack batch
|
||||||
|
@ -67,20 +66,17 @@ class MRCNNModule(pl.LightningModule):
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def on_validation_epoch_start(self):
|
def on_validation_epoch_start(self):
|
||||||
if self.coco is None:
|
self.metric = MeanAveragePrecision(iou_type="bbox")
|
||||||
self.coco = get_coco_api_from_dataset(self.trainer.val_dataloaders[0].dataset)
|
|
||||||
|
|
||||||
# init coco evaluator
|
|
||||||
self.coco_evaluator = CocoEvaluator(self.coco, self.iou_types)
|
|
||||||
|
|
||||||
def validation_step(self, batch, batch_idx):
|
def validation_step(self, batch, batch_idx):
|
||||||
# unpack batch
|
# unpack batch
|
||||||
images, targets = batch
|
images, targets = batch
|
||||||
|
|
||||||
# compute metrics using pycocotools
|
preds = self.model(images)
|
||||||
outputs = self.model(images)
|
for pred, target in zip(preds, targets):
|
||||||
res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
|
pred["masks"] = pred["masks"].squeeze(1).bool()
|
||||||
self.coco_evaluator.update(res)
|
target["masks"] = target["masks"].squeeze(1).bool()
|
||||||
|
self.metric.update(preds, targets)
|
||||||
|
|
||||||
# compute validation loss
|
# compute validation loss
|
||||||
self.model.train()
|
self.model.train()
|
||||||
|
@ -93,48 +89,22 @@ class MRCNNModule(pl.LightningModule):
|
||||||
|
|
||||||
def validation_epoch_end(self, outputs):
|
def validation_epoch_end(self, outputs):
|
||||||
# log validation loss
|
# log validation loss
|
||||||
loss_dict = {k: torch.stack([d[k] for d in outputs]).mean() for k in outputs[0].keys()}
|
loss_dict = {
|
||||||
|
k: torch.stack([d[k] for d in outputs]).mean() for k in outputs[0].keys()
|
||||||
|
} # TODO: update un dict object
|
||||||
self.log_dict(loss_dict)
|
self.log_dict(loss_dict)
|
||||||
|
|
||||||
# accumulate all predictions
|
# log metrics
|
||||||
self.coco_evaluator.accumulate()
|
metric_dict = self.metric.compute()
|
||||||
self.coco_evaluator.summarize()
|
metric_dict = {f"valid/{key}": val for key, val in metric_dict.items()}
|
||||||
|
self.log_dict(metric_dict)
|
||||||
YEET = {
|
|
||||||
"valid,bbox,AP,IoU=0.50:0.,area=all,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[0],
|
|
||||||
"valid,bbox,AP,IoU=0.50,area=all,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[1],
|
|
||||||
"valid,bbox,AP,IoU=0.75,area=all,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[2],
|
|
||||||
"valid,bbox,AP,IoU=0.50:0.,area=small,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[3],
|
|
||||||
"valid,bbox,AP,IoU=0.50:0.,area=medium,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[4],
|
|
||||||
"valid,bbox,AP,IoU=0.50:0.,area=large,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[5],
|
|
||||||
"valid,bbox,AR,IoU=0.50:0.,area=all,maxDets=1": self.coco_evaluator.coco_eval["bbox"].stats[6],
|
|
||||||
"valid,bbox,AR,IoU=0.50:0.,area=all,maxDets=10": self.coco_evaluator.coco_eval["bbox"].stats[7],
|
|
||||||
"valid,bbox,AR,IoU=0.50:0.,area=all,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[8],
|
|
||||||
"valid,bbox,AR,IoU=0.50:0.,area=small,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[9],
|
|
||||||
"valid,bbox,AR,IoU=0.50:0.,area=medium,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[10],
|
|
||||||
"valid,bbox,AR,IoU=0.50:0.,area=large,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[11],
|
|
||||||
"valid,segm,AP,IoU=0.50:0.,area=all,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[0],
|
|
||||||
"valid,segm,AP,IoU=0.50,area=all,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[1],
|
|
||||||
"valid,segm,AP,IoU=0.75,area=all,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[2],
|
|
||||||
"valid,segm,AP,IoU=0.50:0.,area=small,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[3],
|
|
||||||
"valid,segm,AP,IoU=0.50:0.,area=medium,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[4],
|
|
||||||
"valid,segm,AP,IoU=0.50:0.,area=large,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[5],
|
|
||||||
"valid,segm,AR,IoU=0.50:0.,area=all,maxDets=1": self.coco_evaluator.coco_eval["segm"].stats[6],
|
|
||||||
"valid,segm,AR,IoU=0.50:0.,area=all,maxDets=10": self.coco_evaluator.coco_eval["segm"].stats[7],
|
|
||||||
"valid,segm,AR,IoU=0.50:0.,area=all,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[8],
|
|
||||||
"valid,segm,AR,IoU=0.50:0.,area=small,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[9],
|
|
||||||
"valid,segm,AR,IoU=0.50:0.,area=medium,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[10],
|
|
||||||
"valid,segm,AR,IoU=0.50:0.,area=large,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[11],
|
|
||||||
}
|
|
||||||
|
|
||||||
self.log_dict(YEET)
|
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizer = torch.optim.SGD(
|
optimizer = torch.optim.Adam(
|
||||||
self.parameters(),
|
self.parameters(),
|
||||||
lr=wandb.config.LEARNING_RATE,
|
lr=wandb.config.LEARNING_RATE,
|
||||||
momentum=wandb.config.MOMENTUM,
|
# momentum=wandb.config.MOMENTUM,
|
||||||
weight_decay=wandb.config.WEIGHT_DECAY,
|
# weight_decay=wandb.config.WEIGHT_DECAY,
|
||||||
)
|
)
|
||||||
|
|
||||||
# scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
# scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
dbe91cbc0788d93595dac272825aa57412f04d70
|
1163aa4f284163ecb7d4289c4c5217ea28f63770
|
|
@ -1,152 +0,0 @@
|
||||||
"""Pytorch lightning wrapper for model."""
|
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
|
||||||
import torch
|
|
||||||
import torchvision
|
|
||||||
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
|
||||||
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
|
|
||||||
|
|
||||||
import wandb
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_instance_segmentation(num_classes):
|
|
||||||
# load an instance segmentation model pre-trained on COCO
|
|
||||||
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
|
|
||||||
|
|
||||||
# get number of input features for the classifier
|
|
||||||
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
|
||||||
# replace the pre-trained head with a new one
|
|
||||||
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
|
|
||||||
|
|
||||||
# now get the number of input features for the mask classifier
|
|
||||||
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
|
|
||||||
hidden_layer = 256
|
|
||||||
# and replace the mask predictor with a new one
|
|
||||||
model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
class MRCNNModule(pl.LightningModule):
|
|
||||||
def __init__(self, hidden_layer_size, n_classes):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
# Hyperparameters
|
|
||||||
self.hidden_layers_size = hidden_layer_size
|
|
||||||
self.n_classes = n_classes
|
|
||||||
|
|
||||||
# log hyperparameters
|
|
||||||
self.save_hyperparameters()
|
|
||||||
|
|
||||||
# Network
|
|
||||||
self.model = get_model_instance_segmentation(n_classes)
|
|
||||||
|
|
||||||
# onnx
|
|
||||||
self.example_input_array = torch.randn(1, 3, 512, 512, requires_grad=True)
|
|
||||||
|
|
||||||
def forward(self, imgs):
|
|
||||||
self.model.eval()
|
|
||||||
return self.model(imgs)
|
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx):
|
|
||||||
# unpack batch
|
|
||||||
images, targets = batch
|
|
||||||
|
|
||||||
# enable train mode
|
|
||||||
# self.model.train()
|
|
||||||
|
|
||||||
# fasterrcnn takes both images and targets for training
|
|
||||||
loss_dict = self.model(images, targets)
|
|
||||||
loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
|
|
||||||
loss = sum(loss_dict.values())
|
|
||||||
|
|
||||||
# log everything
|
|
||||||
self.log_dict(loss_dict)
|
|
||||||
self.log("train/loss", loss)
|
|
||||||
|
|
||||||
return {"loss": loss, "log": loss_dict}
|
|
||||||
|
|
||||||
# def validation_step(self, batch, batch_idx):
|
|
||||||
# # unpack batch
|
|
||||||
# images, targets = batch
|
|
||||||
|
|
||||||
# # enable eval mode
|
|
||||||
# # self.detector.eval()
|
|
||||||
|
|
||||||
# # make a prediction
|
|
||||||
# preds = self.model(images)
|
|
||||||
|
|
||||||
# # compute validation loss
|
|
||||||
# self.val_loss = torch.mean(
|
|
||||||
# torch.stack(
|
|
||||||
# [
|
|
||||||
# self.accuracy(
|
|
||||||
# target,
|
|
||||||
# pred["boxes"],
|
|
||||||
# iou_threshold=0.5,
|
|
||||||
# )
|
|
||||||
# for target, pred in zip(targets, preds)
|
|
||||||
# ],
|
|
||||||
# )
|
|
||||||
# )
|
|
||||||
|
|
||||||
# return self.val_loss
|
|
||||||
|
|
||||||
# def accuracy(self, src_boxes, pred_boxes, iou_threshold=1.0):
|
|
||||||
# """
|
|
||||||
# The accuracy method is not the one used in the evaluator but very similar
|
|
||||||
# """
|
|
||||||
# total_gt = len(src_boxes)
|
|
||||||
# total_pred = len(pred_boxes)
|
|
||||||
# if total_gt > 0 and total_pred > 0:
|
|
||||||
|
|
||||||
# # Define the matcher and distance matrix based on iou
|
|
||||||
# matcher = Matcher(iou_threshold, iou_threshold, allow_low_quality_matches=False)
|
|
||||||
# match_quality_matrix = box_iou(src_boxes, pred_boxes)
|
|
||||||
|
|
||||||
# results = matcher(match_quality_matrix)
|
|
||||||
|
|
||||||
# true_positive = torch.count_nonzero(results.unique() != -1)
|
|
||||||
# matched_elements = results[results > -1]
|
|
||||||
|
|
||||||
# # in Matcher, a pred element can be matched only twice
|
|
||||||
# false_positive = torch.count_nonzero(results == -1) + (
|
|
||||||
# len(matched_elements) - len(matched_elements.unique())
|
|
||||||
# )
|
|
||||||
# false_negative = total_gt - true_positive
|
|
||||||
|
|
||||||
# return true_positive / (true_positive + false_positive + false_negative)
|
|
||||||
|
|
||||||
# elif total_gt == 0:
|
|
||||||
# if total_pred > 0:
|
|
||||||
# return torch.tensor(0.0).cuda()
|
|
||||||
# else:
|
|
||||||
# return torch.tensor(1.0).cuda()
|
|
||||||
# elif total_gt > 0 and total_pred == 0:
|
|
||||||
# return torch.tensor(0.0).cuda()
|
|
||||||
|
|
||||||
def configure_optimizers(self):
|
|
||||||
optimizer = torch.optim.SGD(
|
|
||||||
self.parameters(),
|
|
||||||
lr=wandb.config.LEARNING_RATE,
|
|
||||||
momentum=wandb.config.MOMENTUM,
|
|
||||||
weight_decay=wandb.config.WEIGHT_DECAY,
|
|
||||||
)
|
|
||||||
|
|
||||||
# scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
|
||||||
# optimizer,
|
|
||||||
# T_0=3,
|
|
||||||
# T_mult=1,
|
|
||||||
# lr=wandb.config.LEARNING_RATE_MIN,
|
|
||||||
# verbose=True,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# return {
|
|
||||||
# "optimizer": optimizer,
|
|
||||||
# "lr_scheduler": {
|
|
||||||
# "scheduler": scheduler,
|
|
||||||
# "monitor": "val_accuracy",
|
|
||||||
# },
|
|
||||||
# }
|
|
||||||
|
|
||||||
return optimizer
|
|
|
@ -1 +1 @@
|
||||||
df787682f8ee4387834b713d1d48437b94d45f61
|
ecf0b9ce39e210bc605fd3eab9db8b1215c35fda
|
70
src/notebooks/test.py
Normal file
70
src/notebooks/test.py
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
import torch
|
||||||
|
from torchmetrics.detection.mean_ap import MeanAveragePrecision
|
||||||
|
|
||||||
|
preds = [
|
||||||
|
dict(
|
||||||
|
boxes=torch.tensor(
|
||||||
|
[
|
||||||
|
[880.0560, 41.7845, 966.9839, 131.3355],
|
||||||
|
[1421.0029, 682.4420, 1512.7570, 765.2380],
|
||||||
|
[132.0775, 818.5026, 216.0825, 1020.8573],
|
||||||
|
]
|
||||||
|
),
|
||||||
|
scores=torch.tensor(
|
||||||
|
[0.9989, 0.9936, 0.0932],
|
||||||
|
),
|
||||||
|
labels=torch.tensor(
|
||||||
|
[1, 1, 1],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
target = [
|
||||||
|
dict(
|
||||||
|
boxes=torch.tensor(
|
||||||
|
[[879, 39, 1513, 766]],
|
||||||
|
),
|
||||||
|
labels=torch.tensor(
|
||||||
|
[1],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
metric = MeanAveragePrecision()
|
||||||
|
metric.update(preds, target)
|
||||||
|
|
||||||
|
from pprint import pprint
|
||||||
|
|
||||||
|
pprint(metric.compute())
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
preds = [
|
||||||
|
dict(
|
||||||
|
boxes=torch.tensor(
|
||||||
|
[
|
||||||
|
[880.0560, 41.7845, 1500, 700.3355],
|
||||||
|
]
|
||||||
|
),
|
||||||
|
scores=torch.tensor(
|
||||||
|
[0.9989],
|
||||||
|
),
|
||||||
|
labels=torch.tensor(
|
||||||
|
[1],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
target = [
|
||||||
|
dict(
|
||||||
|
boxes=torch.tensor(
|
||||||
|
[[879, 39, 1513, 766]],
|
||||||
|
),
|
||||||
|
labels=torch.tensor(
|
||||||
|
[1],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
metric = MeanAveragePrecision()
|
||||||
|
metric.update(preds, target)
|
||||||
|
|
||||||
|
from pprint import pprint
|
||||||
|
|
||||||
|
pprint(metric.compute())
|
39
src/train.py
39
src/train.py
|
@ -1,43 +1,52 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
from pytorch_lightning.callbacks import RichProgressBar
|
import torch
|
||||||
|
import wandb
|
||||||
|
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
|
||||||
from pytorch_lightning.loggers import WandbLogger
|
from pytorch_lightning.loggers import WandbLogger
|
||||||
|
|
||||||
import wandb
|
|
||||||
from data import Spheres
|
from data import Spheres
|
||||||
from mrcnn import MRCNNModule
|
from mrcnn import MRCNNModule
|
||||||
from utils import ArtifactLog, TableLog
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# setup logging
|
# setup logging
|
||||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(levelname)s: %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
# setup wandb
|
# setup wandb
|
||||||
logger = WandbLogger(
|
logger = WandbLogger(
|
||||||
project="Mask R-CNN",
|
project="Mask R-CNN",
|
||||||
config="wandb.yaml",
|
config="wandb.yaml",
|
||||||
|
save_dir="/tmp/",
|
||||||
|
log_model="all",
|
||||||
settings=wandb.Settings(
|
settings=wandb.Settings(
|
||||||
code_dir="./src/",
|
code_dir="./src/",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# seed random generators
|
# seed random generators
|
||||||
pl.seed_everything(wandb.config.SEED, workers=True)
|
pl.seed_everything(
|
||||||
|
seed=wandb.config.SEED,
|
||||||
|
workers=True,
|
||||||
|
)
|
||||||
|
|
||||||
# Create Network
|
# Create Network
|
||||||
module = MRCNNModule(
|
module = MRCNNModule(
|
||||||
hidden_layer_size=-1,
|
|
||||||
n_classes=2,
|
n_classes=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
# load checkpoint
|
# load checkpoint
|
||||||
# state_dict = torch.load("checkpoints/synth.pth")
|
# module.load_state_dict(torch.load()["state_dict"])
|
||||||
# state_dict = dict([(f"model.{key}", value) for key, value in state_dict.items()])
|
# module.load_from_checkpoint("/tmp/model.ckpt")
|
||||||
# model.load_state_dict(state_dict)
|
|
||||||
|
|
||||||
# log gradients and weights regularly
|
# log gradients and weights regularly
|
||||||
logger.watch(module.model, log="all")
|
logger.watch(
|
||||||
|
model=module.model,
|
||||||
|
log="all",
|
||||||
|
)
|
||||||
|
|
||||||
# Create the dataloaders
|
# Create the dataloaders
|
||||||
datamodule = Spheres()
|
datamodule = Spheres()
|
||||||
|
@ -52,15 +61,17 @@ if __name__ == "__main__":
|
||||||
logger=logger,
|
logger=logger,
|
||||||
log_every_n_steps=5,
|
log_every_n_steps=5,
|
||||||
val_check_interval=50,
|
val_check_interval=50,
|
||||||
callbacks=[RichProgressBar(), ArtifactLog()],
|
callbacks=[
|
||||||
# callbacks=[RichProgressBar(), ArtifactLog(), TableLog()],
|
ModelCheckpoint(monitor="valid/loss", mode="min"),
|
||||||
|
RichProgressBar(),
|
||||||
|
],
|
||||||
# profiler="advanced",
|
# profiler="advanced",
|
||||||
num_sanity_val_steps=3,
|
num_sanity_val_steps=3,
|
||||||
devices=[0],
|
devices=[1],
|
||||||
)
|
)
|
||||||
|
|
||||||
# actually train the model
|
# actually train the model
|
||||||
trainer.fit(model=module, datamodule=datamodule)
|
trainer.fit(model=module, datamodule=datamodule)
|
||||||
|
|
||||||
# stop wandb
|
# stop wandb
|
||||||
wandb.run.finish()
|
wandb.run.finish() # type: ignore
|
||||||
|
|
|
@ -1,2 +1,2 @@
|
||||||
from .callback import ArtifactLog, TableLog
|
from .callback import TableLog
|
||||||
from .paste import RandomPaste
|
from .paste import RandomPaste
|
||||||
|
|
|
@ -1,8 +1,5 @@
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from pytorch_lightning.callbacks import Callback
|
|
||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
|
from pytorch_lightning.callbacks import Callback
|
||||||
|
|
||||||
columns = [
|
columns = [
|
||||||
"ID",
|
"ID",
|
||||||
|
@ -64,34 +61,3 @@ class TableLog(Callback):
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ArtifactLog(Callback):
|
|
||||||
# def on_fit_start(self, trainer, pl_module):
|
|
||||||
# self.best = 1
|
|
||||||
|
|
||||||
def on_train_epoch_end(self, trainer, pl_module):
|
|
||||||
# create checkpoint
|
|
||||||
torch.save(pl_module.state_dict(), "checkpoints/model.pth")
|
|
||||||
|
|
||||||
# def on_validation_epoch_start(self, trainer, pl_module):
|
|
||||||
# self.dices = []
|
|
||||||
|
|
||||||
# def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
|
||||||
# # unpacking
|
|
||||||
# metrics, _ = outputs
|
|
||||||
# self.dices.append(metrics["dice"].cpu())
|
|
||||||
|
|
||||||
# def on_validation_epoch_end(self, trainer, pl_module):
|
|
||||||
# dice = np.mean(self.dices)
|
|
||||||
|
|
||||||
# if dice < self.best:
|
|
||||||
# self.best = dice
|
|
||||||
|
|
||||||
# # create checkpoint
|
|
||||||
# trainer.save_checkpoint("checkpoints/model.ckpt")
|
|
||||||
|
|
||||||
# # log artifact
|
|
||||||
# artifact = wandb.Artifact("ckpt", type="model")
|
|
||||||
# artifact.add_file("checkpoints/model.ckpt")
|
|
||||||
# wandb.run.log_artifact(artifact)
|
|
||||||
|
|
16
wandb.yaml
16
wandb.yaml
|
@ -1,9 +1,9 @@
|
||||||
DIR_TRAIN_IMG:
|
# DIR_TRAIN_IMG:
|
||||||
value: "/media/disk1/lfainsin/BACKGROUND/"
|
# value: "/media/disk1/lfainsin/BACKGROUND/"
|
||||||
DIR_VALID_IMG:
|
# DIR_VALID_IMG:
|
||||||
value: "/media/disk1/lfainsin/TEST_batched/"
|
# value: "/media/disk1/lfainsin/TEST_batched/"
|
||||||
DIR_SPHERE:
|
# DIR_SPHERE:
|
||||||
value: "/media/disk1/lfainsin/SPHERES/"
|
# value: "/media/disk1/lfainsin/SPHERES/"
|
||||||
|
|
||||||
N_CHANNELS:
|
N_CHANNELS:
|
||||||
value: 3
|
value: 3
|
||||||
|
@ -28,7 +28,7 @@ WORKERS:
|
||||||
value: 16
|
value: 16
|
||||||
|
|
||||||
EPOCHS:
|
EPOCHS:
|
||||||
value: 10
|
value: 50
|
||||||
TRAIN_BATCH_SIZE:
|
TRAIN_BATCH_SIZE:
|
||||||
value: 10
|
value: 10
|
||||||
VALID_BATCH_SIZE:
|
VALID_BATCH_SIZE:
|
||||||
|
@ -37,7 +37,7 @@ PREFETCH_FACTOR:
|
||||||
value: 2
|
value: 2
|
||||||
|
|
||||||
LEARNING_RATE:
|
LEARNING_RATE:
|
||||||
value: 0.005
|
value: 0.0005
|
||||||
WEIGHT_DECAY:
|
WEIGHT_DECAY:
|
||||||
value: 0.0005
|
value: 0.0005
|
||||||
MOMENTUM:
|
MOMENTUM:
|
||||||
|
|
Loading…
Reference in a new issue