feat: checkpointing + torchmetrics + ...

Former-commit-id: 781a11f646964567b3bc6831caa5e380748b84e1 [formerly 7ea5f37ecb7fd5c95e5d1b7b1aa899d986ecff2a]
Former-commit-id: 06974fa163a6a8eb881d5981f9c6debe63f1b4bb
This commit is contained in:
Laurent Fainsin 2022-09-05 16:04:52 +02:00
parent c0a52196fd
commit c312513eff
15 changed files with 264 additions and 310 deletions

View file

@ -11,7 +11,7 @@ repos:
- id: check-executables-have-shebangs
- id: check-merge-conflict
- id: check-symlinks
- id: check-json
# - id: check-json
- id: check-toml
- id: check-yaml
- id: debug-statements

22
.vscode/launch.json vendored
View file

@ -2,20 +2,28 @@
"version": "0.2.0",
"configurations": [
{
"name": "Python: Current File",
"name": "Train",
"type": "python",
"request": "launch",
"program": "${file}",
"program": "${workspaceFolder}/src/train.py",
"console": "integratedTerminal",
"justMyCode": false,
},
{
"name": "Predict",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/src/predict.py",
"console": "integratedTerminal",
"justMyCode": false,
"args": [
"--input",
"images/test.png",
"images/input.png",
"--output",
"output_onnx.png",
"images/output.png",
"--model",
"good.onnx"
],
"justMyCode": false
"checkpoints/model.onnx"
]
}
]
}

View file

@ -20,5 +20,6 @@
"**/.DS_Store": true,
"**/Thumbs.db": true,
"**/__pycache__": true,
"**/.mypy_cache": true,
}
}

106
poetry.lock generated
View file

@ -203,7 +203,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
name = "coloredlogs"
version = "15.0.1"
description = "Colored terminal output for Python's logging module"
category = "dev"
category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
@ -213,11 +213,22 @@ humanfriendly = ">=9.1"
[package.extras]
cron = ["capturer (>=2.4)"]
[[package]]
name = "commonmark"
version = "0.9.1"
description = "Python parser for the CommonMark Markdown spec"
category = "main"
optional = false
python-versions = "*"
[package.extras]
test = ["flake8 (==3.7.8)", "hypothesis (==3.55.3)"]
[[package]]
name = "cycler"
version = "0.11.0"
description = "Composable style cycles"
category = "dev"
category = "main"
optional = false
python-versions = ">=3.6"
@ -288,7 +299,7 @@ testing = ["covdefaults (>=2.2)", "coverage (>=6.4.2)", "pytest (>=7.1.2)", "pyt
name = "flatbuffers"
version = "2.0.7"
description = "The FlatBuffers serialization format for Python"
category = "dev"
category = "main"
optional = false
python-versions = "*"
@ -296,7 +307,7 @@ python-versions = "*"
name = "fonttools"
version = "4.37.1"
description = "Tools to manipulate font files"
category = "dev"
category = "main"
optional = false
python-versions = ">=3.7"
@ -432,7 +443,7 @@ protobuf = ["grpcio-tools (>=1.48.1)"]
name = "humanfriendly"
version = "10.0"
description = "Human friendly output for text interfaces using Python"
category = "dev"
category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
@ -639,7 +650,7 @@ test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"]
name = "kiwisolver"
version = "1.4.4"
description = "A fast implementation of the Cassowary constraint solver"
category = "dev"
category = "main"
optional = false
python-versions = ">=3.7"
@ -669,7 +680,7 @@ python-versions = ">=3.7"
name = "matplotlib"
version = "3.5.3"
description = "Python plotting package"
category = "dev"
category = "main"
optional = false
python-versions = ">=3.7"
@ -699,7 +710,7 @@ traitlets = "*"
name = "mpmath"
version = "1.2.1"
description = "Python library for arbitrary-precision floating-point arithmetic"
category = "dev"
category = "main"
optional = false
python-versions = "*"
@ -825,6 +836,22 @@ packaging = "*"
protobuf = "*"
sympy = "*"
[[package]]
name = "onnxruntime-gpu"
version = "1.12.1"
description = "ONNX Runtime is a runtime accelerator for Machine Learning models"
category = "main"
optional = false
python-versions = "*"
[package.dependencies]
coloredlogs = "*"
flatbuffers = "*"
numpy = ">=1.21.0"
packaging = "*"
protobuf = "*"
sympy = "*"
[[package]]
name = "opencv-python-headless"
version = "4.6.0.66"
@ -1029,6 +1056,18 @@ python-versions = "*"
[package.dependencies]
pyasn1 = ">=0.4.6,<0.5.0"
[[package]]
name = "pycocotools"
version = "2.0.4"
description = "Official APIs for the MS-COCO dataset"
category = "main"
optional = false
python-versions = ">=3.5"
[package.dependencies]
matplotlib = ">=2.1.0"
numpy = "*"
[[package]]
name = "pycparser"
version = "2.21"
@ -1049,7 +1088,7 @@ python-versions = ">=3.6"
name = "pygments"
version = "2.13.0"
description = "Pygments is a syntax highlighting package written in Python."
category = "dev"
category = "main"
optional = false
python-versions = ">=3.6"
@ -1071,7 +1110,7 @@ diagrams = ["jinja2", "railroad-diagrams"]
name = "pyreadline3"
version = "3.4.1"
description = "A python implementation of GNU readline."
category = "dev"
category = "main"
optional = false
python-versions = "*"
@ -1079,7 +1118,7 @@ python-versions = "*"
name = "python-dateutil"
version = "2.8.2"
description = "Extensions to the standard Python datetime module"
category = "dev"
category = "main"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
@ -1205,6 +1244,22 @@ requests = ">=2.0.0"
[package.extras]
rsa = ["oauthlib[signedtoken] (>=3.0.0)"]
[[package]]
name = "rich"
version = "12.5.1"
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
category = "main"
optional = false
python-versions = ">=3.6.3,<4.0.0"
[package.dependencies]
commonmark = ">=0.9.0,<0.10.0"
pygments = ">=2.6.0,<3.0.0"
typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9\""}
[package.extras]
jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"]
[[package]]
name = "rsa"
version = "4.9"
@ -1318,7 +1373,7 @@ test = ["pytest"]
name = "setuptools-scm"
version = "6.4.2"
description = "the blessed package to manage your versions by scm tags"
category = "dev"
category = "main"
optional = false
python-versions = ">=3.6"
@ -1374,7 +1429,7 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"]
name = "sympy"
version = "1.11.1"
description = "Computer algebra system (CAS) in Python"
category = "dev"
category = "main"
optional = false
python-versions = ">=3.8"
@ -1460,7 +1515,7 @@ python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
name = "tomli"
version = "2.0.1"
description = "A lil' TOML parser"
category = "dev"
category = "main"
optional = false
python-versions = ">=3.7"
@ -1673,7 +1728,7 @@ testing = ["func-timeout", "jaraco.itertools", "pytest (>=6)", "pytest-black (>=
[metadata]
lock-version = "1.1"
python-versions = ">=3.8,<3.11"
content-hash = "41f098cb3b123362c4c88942260ee0fd07c8b175d0f7f4481ab9a6e4dc401d7f"
content-hash = "a937d2dadd3250c71c1d8a43a3a4a72ac53e6f3793d856c17a4f4f199de71d0b"
[metadata.files]
absl-py = [
@ -1901,6 +1956,10 @@ coloredlogs = [
{file = "coloredlogs-15.0.1-py2.py3-none-any.whl", hash = "sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934"},
{file = "coloredlogs-15.0.1.tar.gz", hash = "sha256:7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0"},
]
commonmark = [
{file = "commonmark-0.9.1-py2.py3-none-any.whl", hash = "sha256:da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9"},
{file = "commonmark-0.9.1.tar.gz", hash = "sha256:452f9dc859be7f06631ddcb328b6919c67984aca654e5fefb3914d54691aed60"},
]
cycler = [
{file = "cycler-0.11.0-py3-none-any.whl", hash = "sha256:3a27e95f763a428a739d2add979fa7494c912a32c17c4c38c4d5f082cad165a3"},
{file = "cycler-0.11.0.tar.gz", hash = "sha256:9c87405839a19696e837b3b818fed3f5f69f16f1eec1a1ad77e043dcea9c772f"},
@ -2483,6 +2542,16 @@ onnxruntime = [
{file = "onnxruntime-1.12.1-cp39-cp39-win32.whl", hash = "sha256:a9954f6ffab4a0a3877a4800d817950a236a6db4901399eec1ea52033f52da94"},
{file = "onnxruntime-1.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:76bbd92cbcc5b6b0f893565f072e33f921ae3350a77b74fb7c65757e683516c7"},
]
onnxruntime-gpu = [
{file = "onnxruntime_gpu-1.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42b0393c5122ed90fa2eb76192a486261d86e9526ccb78b2a98923c22791d2d1"},
{file = "onnxruntime_gpu-1.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:ecfe97335027e569d4f46725ba89316041e562b8c499690e25e11cfee4601cd1"},
{file = "onnxruntime_gpu-1.12.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2be6f7f5a1ce0bc8471ce42e10eab92cfb19d0748b857edcb5320b5e98311b7"},
{file = "onnxruntime_gpu-1.12.1-cp37-cp37m-win_amd64.whl", hash = "sha256:d73204323aefebe4eecab9fcf76e22b1a00394e3d838c2962a28a27301186b73"},
{file = "onnxruntime_gpu-1.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b37872527d03d3df10756408ca44014bd6ac354a044ab1c4286cd42dc138e518"},
{file = "onnxruntime_gpu-1.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:296bd9733986cb7517d15bef5535c555d3f863963a71e6575e92d2a854aee61d"},
{file = "onnxruntime_gpu-1.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e46d0724ce54c5908c5760037b78de741fbd48962b370c29ebc20e608b30eda"},
{file = "onnxruntime_gpu-1.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:fd919373be35b9ba54210688265df38ad5e19a530449385c40dab51da407345d"},
]
opencv-python-headless = [
{file = "opencv-python-headless-4.6.0.66.tar.gz", hash = "sha256:d5291d7e10aa2c19cab6fd86f0d61af8617290ecd2d7ffcb051e446868d04cc5"},
{file = "opencv_python_headless-4.6.0.66-cp36-abi3-macosx_10_15_x86_64.whl", hash = "sha256:21e70f8b0c04098cdf466d27184fe6c3820aaef944a22548db95099959c95889"},
@ -2694,6 +2763,9 @@ pyasn1-modules = [
{file = "pyasn1_modules-0.2.8-py3.6.egg", hash = "sha256:cbac4bc38d117f2a49aeedec4407d23e8866ea4ac27ff2cf7fb3e5b570df19e0"},
{file = "pyasn1_modules-0.2.8-py3.7.egg", hash = "sha256:c29a5e5cc7a3f05926aff34e097e84f8589cd790ce0ed41b67aed6857b26aafd"},
]
pycocotools = [
{file = "pycocotools-2.0.4.tar.gz", hash = "sha256:2ab586aa389b9657b6d73c2b9a827a3681f8d00f36490c2e8ab05902e3fd9e93"},
]
pycparser = [
{file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"},
{file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"},
@ -2892,6 +2964,10 @@ requests-oauthlib = [
{file = "requests-oauthlib-1.3.1.tar.gz", hash = "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a"},
{file = "requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5"},
]
rich = [
{file = "rich-12.5.1-py3-none-any.whl", hash = "sha256:2eb4e6894cde1e017976d2975ac210ef515d7548bc595ba20e195fb9628acdeb"},
{file = "rich-12.5.1.tar.gz", hash = "sha256:63a5c5ce3673d3d5fbbf23cd87e11ab84b6b451436f1b7f19ec54b6bc36ed7ca"},
]
rsa = [
{file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"},
{file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"},

View file

@ -14,14 +14,17 @@ torch = "^1.12.1"
torchmetrics = "^0.9.3"
torchvision = "^0.13.1"
wandb = "^0.13.2"
rich = "^12.5.1"
pycocotools = "^2.0.4"
[tool.poetry.dev-dependencies]
black = {extras = ["jupyter"], version = "^22.8.0"}
black = { extras = ["jupyter"], version = "^22.8.0" }
ipykernel = "^6.15.2"
isort = "^5.10.1"
matplotlib = "^3.5.3"
mypy = "^0.971"
onnxruntime = "^1.12.1"
onnxruntime-gpu = "^1.12.1"
pre-commit = "^2.20.0"
[build-system]

View file

@ -1,9 +1,8 @@
import albumentations as A
import pytorch_lightning as pl
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader, Subset
import wandb
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
from .dataset import RealDataset
@ -20,17 +19,18 @@ class Spheres(pl.LightningDataModule):
transforms = A.Compose(
[
A.Flip(),
A.ColorJitter(),
A.ToGray(p=0.01),
A.GaussianBlur(),
A.MotionBlur(),
A.ISONoise(),
A.ImageCompression(),
A.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
max_pixel_value=255,
), # [0, 255] -> coco (?) normalized
# A.ColorJitter(),
# A.ToGray(p=0.01),
# A.GaussianBlur(),
# A.MotionBlur(),
# A.ISONoise(),
# A.ImageCompression(),
# A.Normalize(
# mean=[0.485, 0.456, 0.406],
# std=[0.229, 0.224, 0.225],
# max_pixel_value=255,
# ), # [0, 255] -> coco (?) normalized
A.ToFloat(max_value=255),
ToTensorV2(), # HWC -> CHW
],
bbox_params=A.BboxParams(
@ -57,11 +57,12 @@ class Spheres(pl.LightningDataModule):
def val_dataloader(self):
transforms = A.Compose(
[
A.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
max_pixel_value=255,
), # [0, 255] -> [0.0, 1.0] normalized
# A.Normalize(
# mean=[0.485, 0.456, 0.406],
# std=[0.229, 0.224, 0.225],
# max_pixel_value=255,
# ), # [0, 255] -> [0.0, 1.0] normalized
A.ToFloat(max_value=255),
ToTensorV2(), # HWC -> CHW
],
bbox_params=A.BboxParams(

View file

@ -3,16 +3,14 @@
import pytorch_lightning as pl
import torch
import torchvision
import wandb
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import (
MaskRCNN_ResNet50_FPN_Weights,
MaskRCNNPredictor,
)
import wandb
from utils.coco_eval import CocoEvaluator
from utils.coco_utils import get_coco_api_from_dataset, get_iou_types
def get_model_instance_segmentation(num_classes):
# load an instance segmentation model pre-trained on COCO
@ -33,11 +31,10 @@ def get_model_instance_segmentation(num_classes):
class MRCNNModule(pl.LightningModule):
def __init__(self, hidden_layer_size, n_classes):
def __init__(self, n_classes):
super().__init__()
# Hyperparameters
self.hidden_layers_size = hidden_layer_size
self.n_classes = n_classes
# log hyperparameters
@ -46,10 +43,12 @@ class MRCNNModule(pl.LightningModule):
# Network
self.model = get_model_instance_segmentation(n_classes)
# pycoco evaluator
self.coco = None
self.iou_types = get_iou_types(self.model)
self.coco_evaluator = None
# onnx
self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True)
def forward(self, imgs):
self.model.eval()
return self.model(imgs)
def training_step(self, batch, batch_idx):
# unpack batch
@ -67,20 +66,17 @@ class MRCNNModule(pl.LightningModule):
return loss
def on_validation_epoch_start(self):
if self.coco is None:
self.coco = get_coco_api_from_dataset(self.trainer.val_dataloaders[0].dataset)
# init coco evaluator
self.coco_evaluator = CocoEvaluator(self.coco, self.iou_types)
self.metric = MeanAveragePrecision(iou_type="bbox")
def validation_step(self, batch, batch_idx):
# unpack batch
images, targets = batch
# compute metrics using pycocotools
outputs = self.model(images)
res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
self.coco_evaluator.update(res)
preds = self.model(images)
for pred, target in zip(preds, targets):
pred["masks"] = pred["masks"].squeeze(1).bool()
target["masks"] = target["masks"].squeeze(1).bool()
self.metric.update(preds, targets)
# compute validation loss
self.model.train()
@ -93,48 +89,22 @@ class MRCNNModule(pl.LightningModule):
def validation_epoch_end(self, outputs):
# log validation loss
loss_dict = {k: torch.stack([d[k] for d in outputs]).mean() for k in outputs[0].keys()}
loss_dict = {
k: torch.stack([d[k] for d in outputs]).mean() for k in outputs[0].keys()
} # TODO: update un dict object
self.log_dict(loss_dict)
# accumulate all predictions
self.coco_evaluator.accumulate()
self.coco_evaluator.summarize()
YEET = {
"valid,bbox,AP,IoU=0.50:0.,area=all,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[0],
"valid,bbox,AP,IoU=0.50,area=all,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[1],
"valid,bbox,AP,IoU=0.75,area=all,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[2],
"valid,bbox,AP,IoU=0.50:0.,area=small,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[3],
"valid,bbox,AP,IoU=0.50:0.,area=medium,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[4],
"valid,bbox,AP,IoU=0.50:0.,area=large,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[5],
"valid,bbox,AR,IoU=0.50:0.,area=all,maxDets=1": self.coco_evaluator.coco_eval["bbox"].stats[6],
"valid,bbox,AR,IoU=0.50:0.,area=all,maxDets=10": self.coco_evaluator.coco_eval["bbox"].stats[7],
"valid,bbox,AR,IoU=0.50:0.,area=all,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[8],
"valid,bbox,AR,IoU=0.50:0.,area=small,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[9],
"valid,bbox,AR,IoU=0.50:0.,area=medium,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[10],
"valid,bbox,AR,IoU=0.50:0.,area=large,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[11],
"valid,segm,AP,IoU=0.50:0.,area=all,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[0],
"valid,segm,AP,IoU=0.50,area=all,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[1],
"valid,segm,AP,IoU=0.75,area=all,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[2],
"valid,segm,AP,IoU=0.50:0.,area=small,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[3],
"valid,segm,AP,IoU=0.50:0.,area=medium,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[4],
"valid,segm,AP,IoU=0.50:0.,area=large,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[5],
"valid,segm,AR,IoU=0.50:0.,area=all,maxDets=1": self.coco_evaluator.coco_eval["segm"].stats[6],
"valid,segm,AR,IoU=0.50:0.,area=all,maxDets=10": self.coco_evaluator.coco_eval["segm"].stats[7],
"valid,segm,AR,IoU=0.50:0.,area=all,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[8],
"valid,segm,AR,IoU=0.50:0.,area=small,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[9],
"valid,segm,AR,IoU=0.50:0.,area=medium,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[10],
"valid,segm,AR,IoU=0.50:0.,area=large,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[11],
}
self.log_dict(YEET)
# log metrics
metric_dict = self.metric.compute()
metric_dict = {f"valid/{key}": val for key, val in metric_dict.items()}
self.log_dict(metric_dict)
def configure_optimizers(self):
optimizer = torch.optim.SGD(
optimizer = torch.optim.Adam(
self.parameters(),
lr=wandb.config.LEARNING_RATE,
momentum=wandb.config.MOMENTUM,
weight_decay=wandb.config.WEIGHT_DECAY,
# momentum=wandb.config.MOMENTUM,
# weight_decay=wandb.config.WEIGHT_DECAY,
)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(

View file

@ -1 +1 @@
dbe91cbc0788d93595dac272825aa57412f04d70
1163aa4f284163ecb7d4289c4c5217ea28f63770

View file

@ -1,152 +0,0 @@
"""Pytorch lightning wrapper for model."""
import pytorch_lightning as pl
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
import wandb
def get_model_instance_segmentation(num_classes, hidden_layer=256):
    """Build a Mask R-CNN (ResNet-50 FPN) whose heads predict ``num_classes`` classes.

    The COCO pre-trained torchvision model is loaded and its box predictor and
    mask predictor are swapped for freshly initialised ones.

    Args:
        num_classes: number of classes predicted by the replacement heads
            (presumably counting the background class -- confirm with callers).
        hidden_layer: hidden channel count of the replacement mask predictor.
            Defaults to 256, the value that used to be hard-coded.

    Returns:
        The torchvision Mask R-CNN model with both predictors replaced.
    """
    # load an instance segmentation model pre-trained on COCO; `weights=` is the
    # replacement for the deprecated `pretrained=True` flag and selects the same
    # COCO checkpoint
    weights = torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights.COCO_V1
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=weights)

    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features

    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels

    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

    return model
class MRCNNModule(pl.LightningModule):
    """Lightning wrapper around the Mask R-CNN built by ``get_model_instance_segmentation``.

    Only the training loop is active: the validation step and its accuracy
    helper further down are a commented-out draft.
    """

    def __init__(self, hidden_layer_size, n_classes):
        """Store the hyperparameters and build the network.

        Args:
            hidden_layer_size: stored on the module and saved with the
                hyperparameters, but not passed to the model builder here.
            n_classes: number of classes handed to the model builder.
        """
        super().__init__()

        # Hyperparameters
        self.hidden_layers_size = hidden_layer_size
        self.n_classes = n_classes

        # log hyperparameters
        self.save_hyperparameters()

        # Network
        self.model = get_model_instance_segmentation(n_classes)

        # onnx
        # NOTE(review): presumably the sample input used when exporting/tracing
        # the module -- confirm against the export code.
        self.example_input_array = torch.randn(1, 3, 512, 512, requires_grad=True)

    def forward(self, imgs):
        """Run the wrapped model on ``imgs`` after switching it to eval mode.

        The model is left in eval mode afterwards; nothing here restores
        training mode.
        """
        self.model.eval()
        return self.model(imgs)

    def training_step(self, batch, batch_idx):
        """Compute and log the training losses for one batch.

        Args:
            batch: an ``(images, targets)`` pair, forwarded as-is to the model.
            batch_idx: index of the batch (unused).

        Returns:
            A dict with the summed loss under ``"loss"`` and the individual,
            ``train/``-prefixed losses under ``"log"``.
        """
        # unpack batch
        images, targets = batch

        # enable train mode
        # self.model.train()

        # fasterrcnn takes both images and targets for training
        loss_dict = self.model(images, targets)
        loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
        loss = sum(loss_dict.values())

        # log everything
        self.log_dict(loss_dict)
        self.log("train/loss", loss)

        return {"loss": loss, "log": loss_dict}

    # --- disabled draft: validation step based on a box-matching accuracy ---
    # def validation_step(self, batch, batch_idx):
    #     # unpack batch
    #     images, targets = batch
    #     # enable eval mode
    #     # self.detector.eval()
    #     # make a prediction
    #     preds = self.model(images)
    #     # compute validation loss
    #     self.val_loss = torch.mean(
    #         torch.stack(
    #             [
    #                 self.accuracy(
    #                     target,
    #                     pred["boxes"],
    #                     iou_threshold=0.5,
    #                 )
    #                 for target, pred in zip(targets, preds)
    #             ],
    #         )
    #     )
    #     return self.val_loss

    # def accuracy(self, src_boxes, pred_boxes, iou_threshold=1.0):
    #     """
    #     The accuracy method is not the one used in the evaluator but very similar
    #     """
    #     total_gt = len(src_boxes)
    #     total_pred = len(pred_boxes)
    #     if total_gt > 0 and total_pred > 0:
    #         # Define the matcher and distance matrix based on iou
    #         matcher = Matcher(iou_threshold, iou_threshold, allow_low_quality_matches=False)
    #         match_quality_matrix = box_iou(src_boxes, pred_boxes)
    #         results = matcher(match_quality_matrix)
    #         true_positive = torch.count_nonzero(results.unique() != -1)
    #         matched_elements = results[results > -1]
    #         # in Matcher, a pred element can be matched only twice
    #         false_positive = torch.count_nonzero(results == -1) + (
    #             len(matched_elements) - len(matched_elements.unique())
    #         )
    #         false_negative = total_gt - true_positive
    #         return true_positive / (true_positive + false_positive + false_negative)
    #     elif total_gt == 0:
    #         if total_pred > 0:
    #             return torch.tensor(0.0).cuda()
    #         else:
    #             return torch.tensor(1.0).cuda()
    #     elif total_gt > 0 and total_pred == 0:
    #         return torch.tensor(0.0).cuda()

    def configure_optimizers(self):
        """Return an SGD optimizer over all parameters.

        Learning rate, momentum and weight decay are read from ``wandb.config``,
        so a wandb run must already be initialised when this is called.
        """
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=wandb.config.LEARNING_RATE,
            momentum=wandb.config.MOMENTUM,
            weight_decay=wandb.config.WEIGHT_DECAY,
        )

        # --- disabled draft: learning-rate scheduler ---
        # NOTE(review): the draft passes `lr=`; CosineAnnealingWarmRestarts names
        # its minimum learning rate `eta_min` -- confirm before enabling.
        # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        #     optimizer,
        #     T_0=3,
        #     T_mult=1,
        #     lr=wandb.config.LEARNING_RATE_MIN,
        #     verbose=True,
        # )
        # return {
        #     "optimizer": optimizer,
        #     "lr_scheduler": {
        #         "scheduler": scheduler,
        #         "monitor": "val_accuracy",
        #     },
        # }

        return optimizer

View file

@ -1 +1 @@
df787682f8ee4387834b713d1d48437b94d45f61
ecf0b9ce39e210bc605fd3eab9db8b1215c35fda

70
src/notebooks/test.py Normal file
View file

@ -0,0 +1,70 @@
"""Sanity check of torchmetrics' MeanAveragePrecision on hand-written boxes.

Two sets of predictions are scored against the same single ground-truth box
and the resulting metric dictionaries are pretty-printed for comparison.
"""
from pprint import pprint

import torch
from torchmetrics.detection.mean_ap import MeanAveragePrecision


def ground_truth():
    """Return a fresh one-image target list holding the single reference box."""
    return [
        dict(
            boxes=torch.tensor(
                [[879, 39, 1513, 766]],
            ),
            labels=torch.tensor(
                [1],
            ),
        )
    ]


def report(preds, target):
    """Score ``preds`` against ``target`` with a new metric and print the result.

    A new ``MeanAveragePrecision`` is created on every call so the scenarios
    below do not accumulate into each other.
    """
    metric = MeanAveragePrecision()
    metric.update(preds, target)
    pprint(metric.compute())


# Scenario 1: three predictions of label 1, two with high scores and one with
# a low score.
preds = [
    dict(
        boxes=torch.tensor(
            [
                [880.0560, 41.7845, 966.9839, 131.3355],
                [1421.0029, 682.4420, 1512.7570, 765.2380],
                [132.0775, 818.5026, 216.0825, 1020.8573],
            ]
        ),
        scores=torch.tensor(
            [0.9989, 0.9936, 0.0932],
        ),
        labels=torch.tensor(
            [1, 1, 1],
        ),
    )
]
report(preds, ground_truth())

# --------------------------------------------------------------------------

# Scenario 2: a single high-score prediction of label 1.
preds = [
    dict(
        boxes=torch.tensor(
            [
                [880.0560, 41.7845, 1500, 700.3355],
            ]
        ),
        scores=torch.tensor(
            [0.9989],
        ),
        labels=torch.tensor(
            [1],
        ),
    )
]
report(preds, ground_truth())

View file

@ -1,43 +1,52 @@
import logging
import pytorch_lightning as pl
from pytorch_lightning.callbacks import RichProgressBar
import torch
import wandb
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
from pytorch_lightning.loggers import WandbLogger
import wandb
from data import Spheres
from mrcnn import MRCNNModule
from utils import ArtifactLog, TableLog
if __name__ == "__main__":
# setup logging
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logging.basicConfig(
level=logging.INFO,
format="%(levelname)s: %(message)s",
)
# setup wandb
logger = WandbLogger(
project="Mask R-CNN",
config="wandb.yaml",
save_dir="/tmp/",
log_model="all",
settings=wandb.Settings(
code_dir="./src/",
),
)
# seed random generators
pl.seed_everything(wandb.config.SEED, workers=True)
pl.seed_everything(
seed=wandb.config.SEED,
workers=True,
)
# Create Network
module = MRCNNModule(
hidden_layer_size=-1,
n_classes=2,
)
# load checkpoint
# state_dict = torch.load("checkpoints/synth.pth")
# state_dict = dict([(f"model.{key}", value) for key, value in state_dict.items()])
# model.load_state_dict(state_dict)
# module.load_state_dict(torch.load()["state_dict"])
# module.load_from_checkpoint("/tmp/model.ckpt")
# log gradients and weights regularly
logger.watch(module.model, log="all")
logger.watch(
model=module.model,
log="all",
)
# Create the dataloaders
datamodule = Spheres()
@ -52,15 +61,17 @@ if __name__ == "__main__":
logger=logger,
log_every_n_steps=5,
val_check_interval=50,
callbacks=[RichProgressBar(), ArtifactLog()],
# callbacks=[RichProgressBar(), ArtifactLog(), TableLog()],
callbacks=[
ModelCheckpoint(monitor="valid/loss", mode="min"),
RichProgressBar(),
],
# profiler="advanced",
num_sanity_val_steps=3,
devices=[0],
devices=[1],
)
# actually train the model
trainer.fit(model=module, datamodule=datamodule)
# stop wandb
wandb.run.finish()
wandb.run.finish() # type: ignore

View file

@ -1,2 +1,2 @@
from .callback import ArtifactLog, TableLog
from .callback import TableLog
from .paste import RandomPaste

View file

@ -1,8 +1,5 @@
import numpy as np
import torch
from pytorch_lightning.callbacks import Callback
import wandb
from pytorch_lightning.callbacks import Callback
columns = [
"ID",
@ -64,34 +61,3 @@ class TableLog(Callback):
)
}
)
class ArtifactLog(Callback):
    """Callback that dumps the module's weights to disk after every training epoch.

    The file ``checkpoints/model.pth`` is overwritten each time, so only the
    most recent weights are kept. The commented-out hooks below are a draft of
    a "keep the best checkpoint and log it as a wandb artifact" variant.
    """

    # def on_fit_start(self, trainer, pl_module):
    #     self.best = 1

    def on_train_epoch_end(self, trainer, pl_module):
        """Save ``pl_module.state_dict()`` to ``checkpoints/model.pth``."""
        import os  # local import: only this hook touches the filesystem layout

        # torch.save does not create missing parent directories, so make sure
        # the target folder exists before writing
        os.makedirs("checkpoints", exist_ok=True)

        # create checkpoint
        torch.save(pl_module.state_dict(), "checkpoints/model.pth")

    # def on_validation_epoch_start(self, trainer, pl_module):
    #     self.dices = []

    # def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
    #     # unpacking
    #     metrics, _ = outputs
    #     self.dices.append(metrics["dice"].cpu())

    # def on_validation_epoch_end(self, trainer, pl_module):
    #     dice = np.mean(self.dices)
    #     if dice < self.best:
    #         self.best = dice
    #         # create checkpoint
    #         trainer.save_checkpoint("checkpoints/model.ckpt")
    #         # log artifact
    #         artifact = wandb.Artifact("ckpt", type="model")
    #         artifact.add_file("checkpoints/model.ckpt")
    #         wandb.run.log_artifact(artifact)

View file

@ -1,9 +1,9 @@
DIR_TRAIN_IMG:
value: "/media/disk1/lfainsin/BACKGROUND/"
DIR_VALID_IMG:
value: "/media/disk1/lfainsin/TEST_batched/"
DIR_SPHERE:
value: "/media/disk1/lfainsin/SPHERES/"
# DIR_TRAIN_IMG:
# value: "/media/disk1/lfainsin/BACKGROUND/"
# DIR_VALID_IMG:
# value: "/media/disk1/lfainsin/TEST_batched/"
# DIR_SPHERE:
# value: "/media/disk1/lfainsin/SPHERES/"
N_CHANNELS:
value: 3
@ -28,7 +28,7 @@ WORKERS:
value: 16
EPOCHS:
value: 10
value: 50
TRAIN_BATCH_SIZE:
value: 10
VALID_BATCH_SIZE:
@ -37,7 +37,7 @@ PREFETCH_FACTOR:
value: 2
LEARNING_RATE:
value: 0.005
value: 0.0005
WEIGHT_DECAY:
value: 0.0005
MOMENTUM: