feat: checkpointing + torchmetrics + ...

Former-commit-id: 781a11f646964567b3bc6831caa5e380748b84e1 [formerly 7ea5f37ecb7fd5c95e5d1b7b1aa899d986ecff2a]
Former-commit-id: 06974fa163a6a8eb881d5981f9c6debe63f1b4bb
Laurent Fainsin 2022-09-05 16:04:52 +02:00
parent c0a52196fd
commit c312513eff
15 changed files with 264 additions and 310 deletions

.pre-commit-config.yaml

@@ -11,7 +11,7 @@ repos:
       - id: check-executables-have-shebangs
       - id: check-merge-conflict
       - id: check-symlinks
-      - id: check-json
+      # - id: check-json
       - id: check-toml
       - id: check-yaml
       - id: debug-statements
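Note: check-json is disabled here, presumably because the reworked .vscode/launch.json below uses trailing commas (JSONC), which pre-commit's strict JSON check rejects.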

.vscode/launch.json vendored

@@ -2,20 +2,28 @@
     "version": "0.2.0",
     "configurations": [
         {
-            "name": "Python: Current File",
+            "name": "Train",
             "type": "python",
             "request": "launch",
-            "program": "${file}",
+            "program": "${workspaceFolder}/src/train.py",
             "console": "integratedTerminal",
+            "justMyCode": false,
+        },
+        {
+            "name": "Predict",
+            "type": "python",
+            "request": "launch",
+            "program": "${workspaceFolder}/src/predict.py",
+            "console": "integratedTerminal",
+            "justMyCode": false,
             "args": [
                 "--input",
-                "images/test.png",
+                "images/input.png",
                 "--output",
-                "output_onnx.png",
+                "images/output.png",
                 "--model",
-                "good.onnx"
+                "checkpoints/model.onnx"
-            ],
-            "justMyCode": false
+            ]
         }
     ]
 }
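The Predict configuration above is equivalent to running `python src/predict.py --input images/input.png --output images/output.png --model checkpoints/model.onnx` from the workspace root.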

.vscode/settings.json

@@ -20,5 +20,6 @@
         "**/.DS_Store": true,
         "**/Thumbs.db": true,
         "**/__pycache__": true,
+        "**/.mypy_cache": true,
     }
 }

poetry.lock generated

@@ -203,7 +203,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
 name = "coloredlogs"
 version = "15.0.1"
 description = "Colored terminal output for Python's logging module"
-category = "dev"
+category = "main"
 optional = false
 python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
@@ -213,11 +213,22 @@ humanfriendly = ">=9.1"
 [package.extras]
 cron = ["capturer (>=2.4)"]
 
+[[package]]
+name = "commonmark"
+version = "0.9.1"
+description = "Python parser for the CommonMark Markdown spec"
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.extras]
+test = ["flake8 (==3.7.8)", "hypothesis (==3.55.3)"]
+
 [[package]]
 name = "cycler"
 version = "0.11.0"
 description = "Composable style cycles"
-category = "dev"
+category = "main"
 optional = false
 python-versions = ">=3.6"
@@ -288,7 +299,7 @@ testing = ["covdefaults (>=2.2)", "coverage (>=6.4.2)", "pytest (>=7.1.2)", "pyt
 name = "flatbuffers"
 version = "2.0.7"
 description = "The FlatBuffers serialization format for Python"
-category = "dev"
+category = "main"
 optional = false
 python-versions = "*"
@@ -296,7 +307,7 @@ python-versions = "*"
 name = "fonttools"
 version = "4.37.1"
 description = "Tools to manipulate font files"
-category = "dev"
+category = "main"
 optional = false
 python-versions = ">=3.7"
@@ -432,7 +443,7 @@ protobuf = ["grpcio-tools (>=1.48.1)"]
 name = "humanfriendly"
 version = "10.0"
 description = "Human friendly output for text interfaces using Python"
-category = "dev"
+category = "main"
 optional = false
 python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
@@ -639,7 +650,7 @@ test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"]
 name = "kiwisolver"
 version = "1.4.4"
 description = "A fast implementation of the Cassowary constraint solver"
-category = "dev"
+category = "main"
 optional = false
 python-versions = ">=3.7"
@@ -669,7 +680,7 @@ python-versions = ">=3.7"
 name = "matplotlib"
 version = "3.5.3"
 description = "Python plotting package"
-category = "dev"
+category = "main"
 optional = false
 python-versions = ">=3.7"
@@ -699,7 +710,7 @@ traitlets = "*"
 name = "mpmath"
 version = "1.2.1"
 description = "Python library for arbitrary-precision floating-point arithmetic"
-category = "dev"
+category = "main"
 optional = false
 python-versions = "*"
@@ -825,6 +836,22 @@ packaging = "*"
 protobuf = "*"
 sympy = "*"
 
+[[package]]
+name = "onnxruntime-gpu"
+version = "1.12.1"
+description = "ONNX Runtime is a runtime accelerator for Machine Learning models"
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.dependencies]
+coloredlogs = "*"
+flatbuffers = "*"
+numpy = ">=1.21.0"
+packaging = "*"
+protobuf = "*"
+sympy = "*"
+
 [[package]]
 name = "opencv-python-headless"
 version = "4.6.0.66"
@@ -1029,6 +1056,18 @@ python-versions = "*"
 [package.dependencies]
 pyasn1 = ">=0.4.6,<0.5.0"
 
+[[package]]
+name = "pycocotools"
+version = "2.0.4"
+description = "Official APIs for the MS-COCO dataset"
+category = "main"
+optional = false
+python-versions = ">=3.5"
+
+[package.dependencies]
+matplotlib = ">=2.1.0"
+numpy = "*"
+
 [[package]]
 name = "pycparser"
 version = "2.21"
@@ -1049,7 +1088,7 @@ python-versions = ">=3.6"
 name = "pygments"
 version = "2.13.0"
 description = "Pygments is a syntax highlighting package written in Python."
-category = "dev"
+category = "main"
 optional = false
 python-versions = ">=3.6"
@@ -1071,7 +1110,7 @@ diagrams = ["jinja2", "railroad-diagrams"]
 name = "pyreadline3"
 version = "3.4.1"
 description = "A python implementation of GNU readline."
-category = "dev"
+category = "main"
 optional = false
 python-versions = "*"
@@ -1079,7 +1118,7 @@ python-versions = "*"
 name = "python-dateutil"
 version = "2.8.2"
 description = "Extensions to the standard Python datetime module"
-category = "dev"
+category = "main"
 optional = false
 python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
@@ -1205,6 +1244,22 @@ requests = ">=2.0.0"
 [package.extras]
 rsa = ["oauthlib[signedtoken] (>=3.0.0)"]
 
+[[package]]
+name = "rich"
+version = "12.5.1"
+description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
+category = "main"
+optional = false
+python-versions = ">=3.6.3,<4.0.0"
+
+[package.dependencies]
+commonmark = ">=0.9.0,<0.10.0"
+pygments = ">=2.6.0,<3.0.0"
+typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9\""}
+
+[package.extras]
+jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"]
+
 [[package]]
 name = "rsa"
 version = "4.9"
@@ -1318,7 +1373,7 @@ test = ["pytest"]
 name = "setuptools-scm"
 version = "6.4.2"
 description = "the blessed package to manage your versions by scm tags"
-category = "dev"
+category = "main"
 optional = false
 python-versions = ">=3.6"
@@ -1374,7 +1429,7 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"]
 name = "sympy"
 version = "1.11.1"
 description = "Computer algebra system (CAS) in Python"
-category = "dev"
+category = "main"
 optional = false
 python-versions = ">=3.8"
@@ -1460,7 +1515,7 @@ python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
 name = "tomli"
 version = "2.0.1"
 description = "A lil' TOML parser"
-category = "dev"
+category = "main"
 optional = false
 python-versions = ">=3.7"
@@ -1673,7 +1728,7 @@ testing = ["func-timeout", "jaraco.itertools", "pytest (>=6)", "pytest-black (>=
 [metadata]
 lock-version = "1.1"
 python-versions = ">=3.8,<3.11"
-content-hash = "41f098cb3b123362c4c88942260ee0fd07c8b175d0f7f4481ab9a6e4dc401d7f"
+content-hash = "a937d2dadd3250c71c1d8a43a3a4a72ac53e6f3793d856c17a4f4f199de71d0b"
 
 [metadata.files]
 absl-py = [
@@ -1901,6 +1956,10 @@ coloredlogs = [
     {file = "coloredlogs-15.0.1-py2.py3-none-any.whl", hash = "sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934"},
     {file = "coloredlogs-15.0.1.tar.gz", hash = "sha256:7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0"},
 ]
+commonmark = [
+    {file = "commonmark-0.9.1-py2.py3-none-any.whl", hash = "sha256:da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9"},
+    {file = "commonmark-0.9.1.tar.gz", hash = "sha256:452f9dc859be7f06631ddcb328b6919c67984aca654e5fefb3914d54691aed60"},
+]
 cycler = [
     {file = "cycler-0.11.0-py3-none-any.whl", hash = "sha256:3a27e95f763a428a739d2add979fa7494c912a32c17c4c38c4d5f082cad165a3"},
     {file = "cycler-0.11.0.tar.gz", hash = "sha256:9c87405839a19696e837b3b818fed3f5f69f16f1eec1a1ad77e043dcea9c772f"},
@@ -2483,6 +2542,16 @@ onnxruntime = [
     {file = "onnxruntime-1.12.1-cp39-cp39-win32.whl", hash = "sha256:a9954f6ffab4a0a3877a4800d817950a236a6db4901399eec1ea52033f52da94"},
     {file = "onnxruntime-1.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:76bbd92cbcc5b6b0f893565f072e33f921ae3350a77b74fb7c65757e683516c7"},
 ]
+onnxruntime-gpu = [
+    {file = "onnxruntime_gpu-1.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42b0393c5122ed90fa2eb76192a486261d86e9526ccb78b2a98923c22791d2d1"},
+    {file = "onnxruntime_gpu-1.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:ecfe97335027e569d4f46725ba89316041e562b8c499690e25e11cfee4601cd1"},
+    {file = "onnxruntime_gpu-1.12.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2be6f7f5a1ce0bc8471ce42e10eab92cfb19d0748b857edcb5320b5e98311b7"},
+    {file = "onnxruntime_gpu-1.12.1-cp37-cp37m-win_amd64.whl", hash = "sha256:d73204323aefebe4eecab9fcf76e22b1a00394e3d838c2962a28a27301186b73"},
+    {file = "onnxruntime_gpu-1.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b37872527d03d3df10756408ca44014bd6ac354a044ab1c4286cd42dc138e518"},
+    {file = "onnxruntime_gpu-1.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:296bd9733986cb7517d15bef5535c555d3f863963a71e6575e92d2a854aee61d"},
+    {file = "onnxruntime_gpu-1.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e46d0724ce54c5908c5760037b78de741fbd48962b370c29ebc20e608b30eda"},
+    {file = "onnxruntime_gpu-1.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:fd919373be35b9ba54210688265df38ad5e19a530449385c40dab51da407345d"},
+]
 opencv-python-headless = [
     {file = "opencv-python-headless-4.6.0.66.tar.gz", hash = "sha256:d5291d7e10aa2c19cab6fd86f0d61af8617290ecd2d7ffcb051e446868d04cc5"},
     {file = "opencv_python_headless-4.6.0.66-cp36-abi3-macosx_10_15_x86_64.whl", hash = "sha256:21e70f8b0c04098cdf466d27184fe6c3820aaef944a22548db95099959c95889"},
@@ -2694,6 +2763,9 @@ pyasn1-modules = [
     {file = "pyasn1_modules-0.2.8-py3.6.egg", hash = "sha256:cbac4bc38d117f2a49aeedec4407d23e8866ea4ac27ff2cf7fb3e5b570df19e0"},
     {file = "pyasn1_modules-0.2.8-py3.7.egg", hash = "sha256:c29a5e5cc7a3f05926aff34e097e84f8589cd790ce0ed41b67aed6857b26aafd"},
 ]
+pycocotools = [
+    {file = "pycocotools-2.0.4.tar.gz", hash = "sha256:2ab586aa389b9657b6d73c2b9a827a3681f8d00f36490c2e8ab05902e3fd9e93"},
+]
 pycparser = [
     {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"},
     {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"},
@@ -2892,6 +2964,10 @@ requests-oauthlib = [
     {file = "requests-oauthlib-1.3.1.tar.gz", hash = "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a"},
     {file = "requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5"},
 ]
+rich = [
+    {file = "rich-12.5.1-py3-none-any.whl", hash = "sha256:2eb4e6894cde1e017976d2975ac210ef515d7548bc595ba20e195fb9628acdeb"},
+    {file = "rich-12.5.1.tar.gz", hash = "sha256:63a5c5ce3673d3d5fbbf23cd87e11ab84b6b451436f1b7f19ec54b6bc36ed7ca"},
+]
 rsa = [
     {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"},
     {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"},

pyproject.toml

@@ -14,6 +14,8 @@ torch = "^1.12.1"
 torchmetrics = "^0.9.3"
 torchvision = "^0.13.1"
 wandb = "^0.13.2"
+rich = "^12.5.1"
+pycocotools = "^2.0.4"
 
 [tool.poetry.dev-dependencies]
 black = { extras = ["jupyter"], version = "^22.8.0" }
@@ -22,6 +24,7 @@ isort = "^5.10.1"
 matplotlib = "^3.5.3"
 mypy = "^0.971"
 onnxruntime = "^1.12.1"
+onnxruntime-gpu = "^1.12.1"
 pre-commit = "^2.20.0"
 
 [build-system]
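With onnxruntime-gpu added as a main dependency, the exported model can be served off the GPU. A minimal inference sketch, assuming the checkpoints/model.onnx produced by the module below and its 1×3×1024×1024 tracing input; the four-output unpacking follows torchvision's Mask R-CNN export and is an assumption here:

import numpy as np
import onnxruntime as ort

# Prefer CUDA, fall back to CPU when no GPU is available.
session = ort.InferenceSession(
    "checkpoints/model.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
input_name = session.get_inputs()[0].name
image = np.random.rand(1, 3, 1024, 1024).astype(np.float32)  # placeholder image
boxes, labels, scores, masks = session.run(None, {input_name: image})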


@@ -1,9 +1,8 @@
 import albumentations as A
 import pytorch_lightning as pl
-from albumentations.pytorch import ToTensorV2
-from torch.utils.data import DataLoader, Subset
 import wandb
+from albumentations.pytorch import ToTensorV2
+from torch.utils.data import DataLoader
 
 from .dataset import RealDataset
@@ -20,17 +19,18 @@ class Spheres(pl.LightningDataModule):
         transforms = A.Compose(
             [
                 A.Flip(),
-                A.ColorJitter(),
-                A.ToGray(p=0.01),
-                A.GaussianBlur(),
-                A.MotionBlur(),
-                A.ISONoise(),
-                A.ImageCompression(),
-                A.Normalize(
-                    mean=[0.485, 0.456, 0.406],
-                    std=[0.229, 0.224, 0.225],
-                    max_pixel_value=255,
-                ),  # [0, 255] -> coco (?) normalized
+                # A.ColorJitter(),
+                # A.ToGray(p=0.01),
+                # A.GaussianBlur(),
+                # A.MotionBlur(),
+                # A.ISONoise(),
+                # A.ImageCompression(),
+                # A.Normalize(
+                #     mean=[0.485, 0.456, 0.406],
+                #     std=[0.229, 0.224, 0.225],
+                #     max_pixel_value=255,
+                # ),  # [0, 255] -> coco (?) normalized
+                A.ToFloat(max_value=255),
                 ToTensorV2(),  # HWC -> CHW
             ],
             bbox_params=A.BboxParams(
@@ -57,11 +57,12 @@ class Spheres(pl.LightningDataModule):
     def val_dataloader(self):
         transforms = A.Compose(
             [
-                A.Normalize(
-                    mean=[0.485, 0.456, 0.406],
-                    std=[0.229, 0.224, 0.225],
-                    max_pixel_value=255,
-                ),  # [0, 255] -> [0.0, 1.0] normalized
+                # A.Normalize(
+                #     mean=[0.485, 0.456, 0.406],
+                #     std=[0.229, 0.224, 0.225],
+                #     max_pixel_value=255,
+                # ),  # [0, 255] -> [0.0, 1.0] normalized
+                A.ToFloat(max_value=255),
                 ToTensorV2(),  # HWC -> CHW
             ],
             bbox_params=A.BboxParams(
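The explicit ImageNet A.Normalize is dropped in favour of a plain rescale: torchvision's Mask R-CNN normalizes internally (GeneralizedRCNNTransform), so it expects float images in [0, 1]. A minimal sketch of what the new pipeline does to an image (bbox/mask targets omitted):

import albumentations as A
import numpy as np
from albumentations.pytorch import ToTensorV2

transforms = A.Compose(
    [
        A.Flip(),
        A.ToFloat(max_value=255),  # uint8 [0, 255] -> float32 [0.0, 1.0]
        ToTensorV2(),  # HWC numpy array -> CHW torch.Tensor
    ]
)

image = np.random.randint(0, 256, size=(1024, 1024, 3), dtype=np.uint8)
out = transforms(image=image)["image"]  # torch.float32, shape (3, 1024, 1024)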


@@ -3,16 +3,14 @@
 import pytorch_lightning as pl
 import torch
 import torchvision
+import wandb
+from torchmetrics.detection.mean_ap import MeanAveragePrecision
 from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
 from torchvision.models.detection.mask_rcnn import (
     MaskRCNN_ResNet50_FPN_Weights,
     MaskRCNNPredictor,
 )
-import wandb
-
-from utils.coco_eval import CocoEvaluator
-from utils.coco_utils import get_coco_api_from_dataset, get_iou_types
 
 
 def get_model_instance_segmentation(num_classes):
     # load an instance segmentation model pre-trained on COCO
@@ -33,11 +31,10 @@ def get_model_instance_segmentation(num_classes):
 
 class MRCNNModule(pl.LightningModule):
-    def __init__(self, hidden_layer_size, n_classes):
+    def __init__(self, n_classes):
         super().__init__()
 
         # Hyperparameters
-        self.hidden_layers_size = hidden_layer_size
         self.n_classes = n_classes
 
         # log hyperparameters
@@ -46,10 +43,12 @@ class MRCNNModule(pl.LightningModule):
         # Network
         self.model = get_model_instance_segmentation(n_classes)
 
-        # pycoco evaluator
-        self.coco = None
-        self.iou_types = get_iou_types(self.model)
-        self.coco_evaluator = None
+        # onnx
+        self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True)
+
+    def forward(self, imgs):
+        self.model.eval()
+        return self.model(imgs)
 
     def training_step(self, batch, batch_idx):
         # unpack batch
@@ -67,20 +66,17 @@
         return loss
 
     def on_validation_epoch_start(self):
-        if self.coco is None:
-            self.coco = get_coco_api_from_dataset(self.trainer.val_dataloaders[0].dataset)
-
-        # init coco evaluator
-        self.coco_evaluator = CocoEvaluator(self.coco, self.iou_types)
+        self.metric = MeanAveragePrecision(iou_type="bbox")
 
     def validation_step(self, batch, batch_idx):
         # unpack batch
         images, targets = batch
 
-        # compute metrics using pycocotools
-        outputs = self.model(images)
-        res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
-        self.coco_evaluator.update(res)
+        preds = self.model(images)
+        for pred, target in zip(preds, targets):
+            pred["masks"] = pred["masks"].squeeze(1).bool()
+            target["masks"] = target["masks"].squeeze(1).bool()
+        self.metric.update(preds, targets)
 
         # compute validation loss
         self.model.train()
@@ -93,48 +89,22 @@
     def validation_epoch_end(self, outputs):
         # log validation loss
-        loss_dict = {k: torch.stack([d[k] for d in outputs]).mean() for k in outputs[0].keys()}
+        loss_dict = {
+            k: torch.stack([d[k] for d in outputs]).mean() for k in outputs[0].keys()
+        }  # TODO: update a dict object
         self.log_dict(loss_dict)
 
-        # accumulate all predictions
-        self.coco_evaluator.accumulate()
-        self.coco_evaluator.summarize()
-
-        YEET = {
-            "valid,bbox,AP,IoU=0.50:0.,area=all,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[0],
-            "valid,bbox,AP,IoU=0.50,area=all,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[1],
-            "valid,bbox,AP,IoU=0.75,area=all,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[2],
-            "valid,bbox,AP,IoU=0.50:0.,area=small,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[3],
-            "valid,bbox,AP,IoU=0.50:0.,area=medium,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[4],
-            "valid,bbox,AP,IoU=0.50:0.,area=large,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[5],
-            "valid,bbox,AR,IoU=0.50:0.,area=all,maxDets=1": self.coco_evaluator.coco_eval["bbox"].stats[6],
-            "valid,bbox,AR,IoU=0.50:0.,area=all,maxDets=10": self.coco_evaluator.coco_eval["bbox"].stats[7],
-            "valid,bbox,AR,IoU=0.50:0.,area=all,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[8],
-            "valid,bbox,AR,IoU=0.50:0.,area=small,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[9],
-            "valid,bbox,AR,IoU=0.50:0.,area=medium,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[10],
-            "valid,bbox,AR,IoU=0.50:0.,area=large,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[11],
-            "valid,segm,AP,IoU=0.50:0.,area=all,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[0],
-            "valid,segm,AP,IoU=0.50,area=all,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[1],
-            "valid,segm,AP,IoU=0.75,area=all,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[2],
-            "valid,segm,AP,IoU=0.50:0.,area=small,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[3],
-            "valid,segm,AP,IoU=0.50:0.,area=medium,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[4],
-            "valid,segm,AP,IoU=0.50:0.,area=large,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[5],
-            "valid,segm,AR,IoU=0.50:0.,area=all,maxDets=1": self.coco_evaluator.coco_eval["segm"].stats[6],
-            "valid,segm,AR,IoU=0.50:0.,area=all,maxDets=10": self.coco_evaluator.coco_eval["segm"].stats[7],
-            "valid,segm,AR,IoU=0.50:0.,area=all,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[8],
-            "valid,segm,AR,IoU=0.50:0.,area=small,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[9],
-            "valid,segm,AR,IoU=0.50:0.,area=medium,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[10],
-            "valid,segm,AR,IoU=0.50:0.,area=large,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[11],
-        }
-        self.log_dict(YEET)
+        # log metrics
+        metric_dict = self.metric.compute()
+        metric_dict = {f"valid/{key}": val for key, val in metric_dict.items()}
+        self.log_dict(metric_dict)
 
     def configure_optimizers(self):
-        optimizer = torch.optim.SGD(
+        optimizer = torch.optim.Adam(
             self.parameters(),
             lr=wandb.config.LEARNING_RATE,
-            momentum=wandb.config.MOMENTUM,
-            weight_decay=wandb.config.WEIGHT_DECAY,
+            # momentum=wandb.config.MOMENTUM,
+            # weight_decay=wandb.config.WEIGHT_DECAY,
         )
 
         # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
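The new example_input_array is what makes ONNX export work out of the box: Lightning falls back to it when to_onnx is called without an explicit input sample. A minimal sketch of that export (file path and opset assumed):

import torch

module = MRCNNModule(n_classes=2)
module.eval()

# LightningModule.to_onnx traces with example_input_array when no
# input_sample is passed; Mask R-CNN needs a recent opset to export.
module.to_onnx("checkpoints/model.onnx", opset_version=11)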


@@ -1 +1 @@
-dbe91cbc0788d93595dac272825aa57412f04d70
+1163aa4f284163ecb7d4289c4c5217ea28f63770


@@ -1,152 +0,0 @@
-"""Pytorch lightning wrapper for model."""
-
-import pytorch_lightning as pl
-import torch
-import torchvision
-from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
-from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
-
-import wandb
-
-
-def get_model_instance_segmentation(num_classes):
-    # load an instance segmentation model pre-trained on COCO
-    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
-
-    # get number of input features for the classifier
-    in_features = model.roi_heads.box_predictor.cls_score.in_features
-
-    # replace the pre-trained head with a new one
-    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
-
-    # now get the number of input features for the mask classifier
-    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
-    hidden_layer = 256
-
-    # and replace the mask predictor with a new one
-    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
-
-    return model
-
-
-class MRCNNModule(pl.LightningModule):
-    def __init__(self, hidden_layer_size, n_classes):
-        super().__init__()
-
-        # Hyperparameters
-        self.hidden_layers_size = hidden_layer_size
-        self.n_classes = n_classes
-
-        # log hyperparameters
-        self.save_hyperparameters()
-
-        # Network
-        self.model = get_model_instance_segmentation(n_classes)
-
-        # onnx
-        self.example_input_array = torch.randn(1, 3, 512, 512, requires_grad=True)
-
-    def forward(self, imgs):
-        self.model.eval()
-        return self.model(imgs)
-
-    def training_step(self, batch, batch_idx):
-        # unpack batch
-        images, targets = batch
-
-        # enable train mode
-        # self.model.train()
-
-        # fasterrcnn takes both images and targets for training
-        loss_dict = self.model(images, targets)
-        loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
-        loss = sum(loss_dict.values())
-
-        # log everything
-        self.log_dict(loss_dict)
-        self.log("train/loss", loss)
-
-        return {"loss": loss, "log": loss_dict}
-
-    # def validation_step(self, batch, batch_idx):
-    #     # unpack batch
-    #     images, targets = batch
-
-    #     # enable eval mode
-    #     # self.detector.eval()
-
-    #     # make a prediction
-    #     preds = self.model(images)
-
-    #     # compute validation loss
-    #     self.val_loss = torch.mean(
-    #         torch.stack(
-    #             [
-    #                 self.accuracy(
-    #                     target,
-    #                     pred["boxes"],
-    #                     iou_threshold=0.5,
-    #                 )
-    #                 for target, pred in zip(targets, preds)
-    #             ],
-    #         )
-    #     )
-
-    #     return self.val_loss
-
-    # def accuracy(self, src_boxes, pred_boxes, iou_threshold=1.0):
-    #     """
-    #     The accuracy method is not the one used in the evaluator but very similar
-    #     """
-    #     total_gt = len(src_boxes)
-    #     total_pred = len(pred_boxes)
-    #     if total_gt > 0 and total_pred > 0:
-    #         # Define the matcher and distance matrix based on iou
-    #         matcher = Matcher(iou_threshold, iou_threshold, allow_low_quality_matches=False)
-    #         match_quality_matrix = box_iou(src_boxes, pred_boxes)
-    #         results = matcher(match_quality_matrix)
-    #         true_positive = torch.count_nonzero(results.unique() != -1)
-    #         matched_elements = results[results > -1]
-
-    #         # in Matcher, a pred element can be matched only twice
-    #         false_positive = torch.count_nonzero(results == -1) + (
-    #             len(matched_elements) - len(matched_elements.unique())
-    #         )
-    #         false_negative = total_gt - true_positive
-
-    #         return true_positive / (true_positive + false_positive + false_negative)
-
-    #     elif total_gt == 0:
-    #         if total_pred > 0:
-    #             return torch.tensor(0.0).cuda()
-    #         else:
-    #             return torch.tensor(1.0).cuda()
-
-    #     elif total_gt > 0 and total_pred == 0:
-    #         return torch.tensor(0.0).cuda()
-
-    def configure_optimizers(self):
-        optimizer = torch.optim.SGD(
-            self.parameters(),
-            lr=wandb.config.LEARNING_RATE,
-            momentum=wandb.config.MOMENTUM,
-            weight_decay=wandb.config.WEIGHT_DECAY,
-        )
-
-        # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
-        #     optimizer,
-        #     T_0=3,
-        #     T_mult=1,
-        #     lr=wandb.config.LEARNING_RATE_MIN,
-        #     verbose=True,
-        # )
-        # return {
-        #     "optimizer": optimizer,
-        #     "lr_scheduler": {
-        #         "scheduler": scheduler,
-        #         "monitor": "val_accuracy",
-        #     },
-        # }
-
-        return optimizer

@@ -1 +1 @@
-df787682f8ee4387834b713d1d48437b94d45f61
+ecf0b9ce39e210bc605fd3eab9db8b1215c35fda

src/notebooks/test.py Normal file

@@ -0,0 +1,70 @@
+import torch
+from torchmetrics.detection.mean_ap import MeanAveragePrecision
+
+preds = [
+    dict(
+        boxes=torch.tensor(
+            [
+                [880.0560, 41.7845, 966.9839, 131.3355],
+                [1421.0029, 682.4420, 1512.7570, 765.2380],
+                [132.0775, 818.5026, 216.0825, 1020.8573],
+            ]
+        ),
+        scores=torch.tensor(
+            [0.9989, 0.9936, 0.0932],
+        ),
+        labels=torch.tensor(
+            [1, 1, 1],
+        ),
+    )
+]
+
+target = [
+    dict(
+        boxes=torch.tensor(
+            [[879, 39, 1513, 766]],
+        ),
+        labels=torch.tensor(
+            [1],
+        ),
+    )
+]
+
+metric = MeanAveragePrecision()
+metric.update(preds, target)
+
+from pprint import pprint
+
+pprint(metric.compute())
+
+# --------------------------------------------------------------------------
+
+preds = [
+    dict(
+        boxes=torch.tensor(
+            [
+                [880.0560, 41.7845, 1500, 700.3355],
+            ]
+        ),
+        scores=torch.tensor(
+            [0.9989],
+        ),
+        labels=torch.tensor(
+            [1],
+        ),
+    )
+]
+
+target = [
+    dict(
+        boxes=torch.tensor(
+            [[879, 39, 1513, 766]],
+        ),
+        labels=torch.tensor(
+            [1],
+        ),
+    )
+]
+
+metric = MeanAveragePrecision()
+metric.update(preds, target)
+
+from pprint import pprint
+
+pprint(metric.compute())
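For reference, compute() returns a dict of tensors with keys such as map, map_50, map_75, map_small/medium/large and mar_1, mar_10, mar_100. The module above instantiates MeanAveragePrecision(iou_type="bbox"), but its validation_step already reshapes masks into the boolean (N, H, W) layout that the "segm" variant expects; a hedged sketch of mask mAP with shapes and values invented for illustration:

import torch
from torchmetrics.detection.mean_ap import MeanAveragePrecision

# torchmetrics expects boolean instance masks of shape (num_instances, H, W),
# hence the .squeeze(1).bool() casts in validation_step.
mask = torch.zeros(1, 64, 64, dtype=torch.bool)
mask[:, 10:40, 10:40] = True

preds = [
    dict(
        boxes=torch.tensor([[10.0, 10.0, 40.0, 40.0]]),
        masks=mask,
        scores=torch.tensor([0.9]),
        labels=torch.tensor([1]),
    )
]
target = [
    dict(
        boxes=torch.tensor([[10.0, 10.0, 40.0, 40.0]]),
        masks=mask,
        labels=torch.tensor([1]),
    )
]

metric = MeanAveragePrecision(iou_type="segm")
metric.update(preds, target)
print(metric.compute()["map"])  # tensor(1.) for a perfect match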

src/train.py

@@ -1,43 +1,52 @@
 import logging
 
 import pytorch_lightning as pl
-from pytorch_lightning.callbacks import RichProgressBar
+import torch
+import wandb
+from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
 from pytorch_lightning.loggers import WandbLogger
-
-import wandb
 
 from data import Spheres
 from mrcnn import MRCNNModule
-from utils import ArtifactLog, TableLog
 
 if __name__ == "__main__":
     # setup logging
-    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(levelname)s: %(message)s",
+    )
 
     # setup wandb
     logger = WandbLogger(
         project="Mask R-CNN",
         config="wandb.yaml",
+        save_dir="/tmp/",
+        log_model="all",
         settings=wandb.Settings(
             code_dir="./src/",
         ),
     )
 
     # seed random generators
-    pl.seed_everything(wandb.config.SEED, workers=True)
+    pl.seed_everything(
+        seed=wandb.config.SEED,
+        workers=True,
+    )
 
     # Create Network
     module = MRCNNModule(
-        hidden_layer_size=-1,
         n_classes=2,
     )
 
     # load checkpoint
-    # state_dict = torch.load("checkpoints/synth.pth")
-    # state_dict = dict([(f"model.{key}", value) for key, value in state_dict.items()])
-    # model.load_state_dict(state_dict)
+    # module.load_state_dict(torch.load()["state_dict"])
+    # module.load_from_checkpoint("/tmp/model.ckpt")
 
     # log gradients and weights regularly
-    logger.watch(module.model, log="all")
+    logger.watch(
+        model=module.model,
+        log="all",
+    )
 
     # Create the dataloaders
     datamodule = Spheres()
@@ -52,15 +61,17 @@ if __name__ == "__main__":
         logger=logger,
         log_every_n_steps=5,
         val_check_interval=50,
-        callbacks=[RichProgressBar(), ArtifactLog()],
-        # callbacks=[RichProgressBar(), ArtifactLog(), TableLog()],
+        callbacks=[
+            ModelCheckpoint(monitor="valid/loss", mode="min"),
+            RichProgressBar(),
+        ],
         # profiler="advanced",
         num_sanity_val_steps=3,
-        devices=[0],
+        devices=[1],
     )
 
     # actually train the model
     trainer.fit(model=module, datamodule=datamodule)
 
     # stop wandb
-    wandb.run.finish()
+    wandb.run.finish()  # type: ignore
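Because ModelCheckpoint now monitors valid/loss, the best weights are tracked by the trainer and can be restored after fit. A minimal sketch, continuing the script above:

# Path of the best checkpoint written by the ModelCheckpoint callback.
best_path = trainer.checkpoint_callback.best_model_path

# Rebuild the module from it for inference or ONNX export.
module = MRCNNModule.load_from_checkpoint(best_path, n_classes=2)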

src/utils/__init__.py

@@ -1,2 +1,2 @@
-from .callback import ArtifactLog, TableLog
+from .callback import TableLog
 from .paste import RandomPaste

src/utils/callback.py

@@ -1,8 +1,5 @@
-import numpy as np
-import torch
-from pytorch_lightning.callbacks import Callback
-
 import wandb
+from pytorch_lightning.callbacks import Callback
 
 columns = [
     "ID",
@@ -64,34 +61,3 @@ class TableLog(Callback):
             )
         }
     )
-
-
-class ArtifactLog(Callback):
-    # def on_fit_start(self, trainer, pl_module):
-    #     self.best = 1
-
-    def on_train_epoch_end(self, trainer, pl_module):
-        # create checkpoint
-        torch.save(pl_module.state_dict(), "checkpoints/model.pth")
-
-    # def on_validation_epoch_start(self, trainer, pl_module):
-    #     self.dices = []
-
-    # def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-    #     # unpacking
-    #     metrics, _ = outputs
-    #     self.dices.append(metrics["dice"].cpu())
-
-    # def on_validation_epoch_end(self, trainer, pl_module):
-    #     dice = np.mean(self.dices)
-
-    #     if dice < self.best:
-    #         self.best = dice
-
-    #         # create checkpoint
-    #         trainer.save_checkpoint("checkpoints/model.ckpt")
-
-    #         # log artifact
-    #         artifact = wandb.Artifact("ckpt", type="model")
-    #         artifact.add_file("checkpoints/model.ckpt")
-    #         wandb.run.log_artifact(artifact)

wandb.yaml

@@ -1,9 +1,9 @@
-DIR_TRAIN_IMG:
-  value: "/media/disk1/lfainsin/BACKGROUND/"
-DIR_VALID_IMG:
-  value: "/media/disk1/lfainsin/TEST_batched/"
-DIR_SPHERE:
-  value: "/media/disk1/lfainsin/SPHERES/"
+# DIR_TRAIN_IMG:
+#   value: "/media/disk1/lfainsin/BACKGROUND/"
+# DIR_VALID_IMG:
+#   value: "/media/disk1/lfainsin/TEST_batched/"
+# DIR_SPHERE:
+#   value: "/media/disk1/lfainsin/SPHERES/"
 
 N_CHANNELS:
   value: 3
@@ -28,7 +28,7 @@ WORKERS:
   value: 16
 
 EPOCHS:
-  value: 10
+  value: 50
 TRAIN_BATCH_SIZE:
   value: 10
 VALID_BATCH_SIZE:
@@ -37,7 +37,7 @@ PREFETCH_FACTOR:
   value: 2
 
 LEARNING_RATE:
-  value: 0.005
+  value: 0.0005
 WEIGHT_DECAY:
   value: 0.0005
 MOMENTUM:
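These values are loaded onto wandb.config when WandbLogger is constructed with config="wandb.yaml", which is how the training script and module read them. A small sketch of the access pattern, assuming an active run:

import wandb

epochs = wandb.config.EPOCHS           # 50
lr = wandb.config.LEARNING_RATE        # 0.0005, read by configure_optimizers
batch = wandb.config.TRAIN_BATCH_SIZE  # 10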