diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 63ac6b7..a441a3c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: - id: check-executables-have-shebangs - id: check-merge-conflict - id: check-symlinks - - id: check-json + # - id: check-json - id: check-toml - id: check-yaml - id: debug-statements diff --git a/.vscode/launch.json b/.vscode/launch.json index bdbded8..191d204 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -2,20 +2,28 @@ "version": "0.2.0", "configurations": [ { - "name": "Python: Current File", + "name": "Train", "type": "python", "request": "launch", - "program": "${file}", + "program": "${workspaceFolder}/src/train.py", "console": "integratedTerminal", + "justMyCode": false, + }, + { + "name": "Predict", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/src/predict.py", + "console": "integratedTerminal", + "justMyCode": false, "args": [ "--input", - "images/test.png", + "images/input.png", "--output", - "output_onnx.png", + "images/output.png", "--model", - "good.onnx" - ], - "justMyCode": false + "checkpoints/model.onnx" + ] } ] } diff --git a/.vscode/settings.json b/.vscode/settings.json index e6d4181..446f92c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -20,5 +20,6 @@ "**/.DS_Store": true, "**/Thumbs.db": true, "**/__pycache__": true, + "**/.mypy_cache": true, } } diff --git a/poetry.lock b/poetry.lock index c33813d..bcf0f65 100644 --- a/poetry.lock +++ b/poetry.lock @@ -203,7 +203,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" name = "coloredlogs" version = "15.0.1" description = "Colored terminal output for Python's logging module" -category = "dev" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" @@ -213,11 +213,22 @@ humanfriendly = ">=9.1" [package.extras] cron = ["capturer (>=2.4)"] +[[package]] +name = "commonmark" +version = "0.9.1" +description = "Python parser for the CommonMark Markdown spec" +category = "main" +optional = false +python-versions = "*" + +[package.extras] +test = ["flake8 (==3.7.8)", "hypothesis (==3.55.3)"] + [[package]] name = "cycler" version = "0.11.0" description = "Composable style cycles" -category = "dev" +category = "main" optional = false python-versions = ">=3.6" @@ -288,7 +299,7 @@ testing = ["covdefaults (>=2.2)", "coverage (>=6.4.2)", "pytest (>=7.1.2)", "pyt name = "flatbuffers" version = "2.0.7" description = "The FlatBuffers serialization format for Python" -category = "dev" +category = "main" optional = false python-versions = "*" @@ -296,7 +307,7 @@ python-versions = "*" name = "fonttools" version = "4.37.1" description = "Tools to manipulate font files" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" @@ -432,7 +443,7 @@ protobuf = ["grpcio-tools (>=1.48.1)"] name = "humanfriendly" version = "10.0" description = "Human friendly output for text interfaces using Python" -category = "dev" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" @@ -639,7 +650,7 @@ test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] name = "kiwisolver" version = "1.4.4" description = "A fast implementation of the Cassowary constraint solver" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" @@ -669,7 +680,7 @@ python-versions = ">=3.7" name = "matplotlib" version = "3.5.3" description = "Python plotting package" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" @@ -699,7 +710,7 @@ traitlets = "*" name = "mpmath" version = "1.2.1" description = "Python library for arbitrary-precision floating-point arithmetic" -category = "dev" +category = "main" optional = false python-versions = "*" @@ -825,6 +836,22 @@ packaging = "*" protobuf = "*" sympy = "*" +[[package]] +name = "onnxruntime-gpu" +version = "1.12.1" +description = "ONNX Runtime is a runtime accelerator for Machine Learning models" +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +coloredlogs = "*" +flatbuffers = "*" +numpy = ">=1.21.0" +packaging = "*" +protobuf = "*" +sympy = "*" + [[package]] name = "opencv-python-headless" version = "4.6.0.66" @@ -1029,6 +1056,18 @@ python-versions = "*" [package.dependencies] pyasn1 = ">=0.4.6,<0.5.0" +[[package]] +name = "pycocotools" +version = "2.0.4" +description = "Official APIs for the MS-COCO dataset" +category = "main" +optional = false +python-versions = ">=3.5" + +[package.dependencies] +matplotlib = ">=2.1.0" +numpy = "*" + [[package]] name = "pycparser" version = "2.21" @@ -1049,7 +1088,7 @@ python-versions = ">=3.6" name = "pygments" version = "2.13.0" description = "Pygments is a syntax highlighting package written in Python." -category = "dev" +category = "main" optional = false python-versions = ">=3.6" @@ -1071,7 +1110,7 @@ diagrams = ["jinja2", "railroad-diagrams"] name = "pyreadline3" version = "3.4.1" description = "A python implementation of GNU readline." -category = "dev" +category = "main" optional = false python-versions = "*" @@ -1079,7 +1118,7 @@ python-versions = "*" name = "python-dateutil" version = "2.8.2" description = "Extensions to the standard Python datetime module" -category = "dev" +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" @@ -1205,6 +1244,22 @@ requests = ">=2.0.0" [package.extras] rsa = ["oauthlib[signedtoken] (>=3.0.0)"] +[[package]] +name = "rich" +version = "12.5.1" +description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +category = "main" +optional = false +python-versions = ">=3.6.3,<4.0.0" + +[package.dependencies] +commonmark = ">=0.9.0,<0.10.0" +pygments = ">=2.6.0,<3.0.0" +typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9\""} + +[package.extras] +jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"] + [[package]] name = "rsa" version = "4.9" @@ -1318,7 +1373,7 @@ test = ["pytest"] name = "setuptools-scm" version = "6.4.2" description = "the blessed package to manage your versions by scm tags" -category = "dev" +category = "main" optional = false python-versions = ">=3.6" @@ -1374,7 +1429,7 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] name = "sympy" version = "1.11.1" description = "Computer algebra system (CAS) in Python" -category = "dev" +category = "main" optional = false python-versions = ">=3.8" @@ -1460,7 +1515,7 @@ python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" name = "tomli" version = "2.0.1" description = "A lil' TOML parser" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" @@ -1673,7 +1728,7 @@ testing = ["func-timeout", "jaraco.itertools", "pytest (>=6)", "pytest-black (>= [metadata] lock-version = "1.1" python-versions = ">=3.8,<3.11" -content-hash = "41f098cb3b123362c4c88942260ee0fd07c8b175d0f7f4481ab9a6e4dc401d7f" +content-hash = "a937d2dadd3250c71c1d8a43a3a4a72ac53e6f3793d856c17a4f4f199de71d0b" [metadata.files] absl-py = [ @@ -1901,6 +1956,10 @@ coloredlogs = [ {file = "coloredlogs-15.0.1-py2.py3-none-any.whl", hash = "sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934"}, {file = "coloredlogs-15.0.1.tar.gz", hash = "sha256:7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0"}, ] +commonmark = [ + {file = "commonmark-0.9.1-py2.py3-none-any.whl", hash = "sha256:da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9"}, + {file = "commonmark-0.9.1.tar.gz", hash = "sha256:452f9dc859be7f06631ddcb328b6919c67984aca654e5fefb3914d54691aed60"}, +] cycler = [ {file = "cycler-0.11.0-py3-none-any.whl", hash = "sha256:3a27e95f763a428a739d2add979fa7494c912a32c17c4c38c4d5f082cad165a3"}, {file = "cycler-0.11.0.tar.gz", hash = "sha256:9c87405839a19696e837b3b818fed3f5f69f16f1eec1a1ad77e043dcea9c772f"}, @@ -2483,6 +2542,16 @@ onnxruntime = [ {file = "onnxruntime-1.12.1-cp39-cp39-win32.whl", hash = "sha256:a9954f6ffab4a0a3877a4800d817950a236a6db4901399eec1ea52033f52da94"}, {file = "onnxruntime-1.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:76bbd92cbcc5b6b0f893565f072e33f921ae3350a77b74fb7c65757e683516c7"}, ] +onnxruntime-gpu = [ + {file = "onnxruntime_gpu-1.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42b0393c5122ed90fa2eb76192a486261d86e9526ccb78b2a98923c22791d2d1"}, + {file = "onnxruntime_gpu-1.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:ecfe97335027e569d4f46725ba89316041e562b8c499690e25e11cfee4601cd1"}, + {file = "onnxruntime_gpu-1.12.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2be6f7f5a1ce0bc8471ce42e10eab92cfb19d0748b857edcb5320b5e98311b7"}, + {file = "onnxruntime_gpu-1.12.1-cp37-cp37m-win_amd64.whl", hash = "sha256:d73204323aefebe4eecab9fcf76e22b1a00394e3d838c2962a28a27301186b73"}, + {file = "onnxruntime_gpu-1.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b37872527d03d3df10756408ca44014bd6ac354a044ab1c4286cd42dc138e518"}, + {file = "onnxruntime_gpu-1.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:296bd9733986cb7517d15bef5535c555d3f863963a71e6575e92d2a854aee61d"}, + {file = "onnxruntime_gpu-1.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e46d0724ce54c5908c5760037b78de741fbd48962b370c29ebc20e608b30eda"}, + {file = "onnxruntime_gpu-1.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:fd919373be35b9ba54210688265df38ad5e19a530449385c40dab51da407345d"}, +] opencv-python-headless = [ {file = "opencv-python-headless-4.6.0.66.tar.gz", hash = "sha256:d5291d7e10aa2c19cab6fd86f0d61af8617290ecd2d7ffcb051e446868d04cc5"}, {file = "opencv_python_headless-4.6.0.66-cp36-abi3-macosx_10_15_x86_64.whl", hash = "sha256:21e70f8b0c04098cdf466d27184fe6c3820aaef944a22548db95099959c95889"}, @@ -2694,6 +2763,9 @@ pyasn1-modules = [ {file = "pyasn1_modules-0.2.8-py3.6.egg", hash = "sha256:cbac4bc38d117f2a49aeedec4407d23e8866ea4ac27ff2cf7fb3e5b570df19e0"}, {file = "pyasn1_modules-0.2.8-py3.7.egg", hash = "sha256:c29a5e5cc7a3f05926aff34e097e84f8589cd790ce0ed41b67aed6857b26aafd"}, ] +pycocotools = [ + {file = "pycocotools-2.0.4.tar.gz", hash = "sha256:2ab586aa389b9657b6d73c2b9a827a3681f8d00f36490c2e8ab05902e3fd9e93"}, +] pycparser = [ {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"}, {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, @@ -2892,6 +2964,10 @@ requests-oauthlib = [ {file = "requests-oauthlib-1.3.1.tar.gz", hash = "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a"}, {file = "requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5"}, ] +rich = [ + {file = "rich-12.5.1-py3-none-any.whl", hash = "sha256:2eb4e6894cde1e017976d2975ac210ef515d7548bc595ba20e195fb9628acdeb"}, + {file = "rich-12.5.1.tar.gz", hash = "sha256:63a5c5ce3673d3d5fbbf23cd87e11ab84b6b451436f1b7f19ec54b6bc36ed7ca"}, +] rsa = [ {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"}, {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"}, diff --git a/pyproject.toml b/pyproject.toml index f682d02..79d825d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,14 +14,17 @@ torch = "^1.12.1" torchmetrics = "^0.9.3" torchvision = "^0.13.1" wandb = "^0.13.2" +rich = "^12.5.1" +pycocotools = "^2.0.4" [tool.poetry.dev-dependencies] -black = {extras = ["jupyter"], version = "^22.8.0"} +black = { extras = ["jupyter"], version = "^22.8.0" } ipykernel = "^6.15.2" isort = "^5.10.1" matplotlib = "^3.5.3" mypy = "^0.971" onnxruntime = "^1.12.1" +onnxruntime-gpu = "^1.12.1" pre-commit = "^2.20.0" [build-system] diff --git a/src/data/dataloader.py b/src/data/dataloader.py index 9430ac8..3898ea5 100644 --- a/src/data/dataloader.py +++ b/src/data/dataloader.py @@ -1,9 +1,8 @@ import albumentations as A import pytorch_lightning as pl -from albumentations.pytorch import ToTensorV2 -from torch.utils.data import DataLoader, Subset - import wandb +from albumentations.pytorch import ToTensorV2 +from torch.utils.data import DataLoader from .dataset import RealDataset @@ -20,17 +19,18 @@ class Spheres(pl.LightningDataModule): transforms = A.Compose( [ A.Flip(), - A.ColorJitter(), - A.ToGray(p=0.01), - A.GaussianBlur(), - A.MotionBlur(), - A.ISONoise(), - A.ImageCompression(), - A.Normalize( - mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225], - max_pixel_value=255, - ), # [0, 255] -> coco (?) normalized + # A.ColorJitter(), + # A.ToGray(p=0.01), + # A.GaussianBlur(), + # A.MotionBlur(), + # A.ISONoise(), + # A.ImageCompression(), + # A.Normalize( + # mean=[0.485, 0.456, 0.406], + # std=[0.229, 0.224, 0.225], + # max_pixel_value=255, + # ), # [0, 255] -> coco (?) normalized + A.ToFloat(max_value=255), ToTensorV2(), # HWC -> CHW ], bbox_params=A.BboxParams( @@ -57,11 +57,12 @@ class Spheres(pl.LightningDataModule): def val_dataloader(self): transforms = A.Compose( [ - A.Normalize( - mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225], - max_pixel_value=255, - ), # [0, 255] -> [0.0, 1.0] normalized + # A.Normalize( + # mean=[0.485, 0.456, 0.406], + # std=[0.229, 0.224, 0.225], + # max_pixel_value=255, + # ), # [0, 255] -> [0.0, 1.0] normalized + A.ToFloat(max_value=255), ToTensorV2(), # HWC -> CHW ], bbox_params=A.BboxParams( diff --git a/src/mrcnn/module.py b/src/mrcnn/module.py index f2ece83..6b93a88 100644 --- a/src/mrcnn/module.py +++ b/src/mrcnn/module.py @@ -3,16 +3,14 @@ import pytorch_lightning as pl import torch import torchvision +import wandb +from torchmetrics.detection.mean_ap import MeanAveragePrecision from torchvision.models.detection.faster_rcnn import FastRCNNPredictor from torchvision.models.detection.mask_rcnn import ( MaskRCNN_ResNet50_FPN_Weights, MaskRCNNPredictor, ) -import wandb -from utils.coco_eval import CocoEvaluator -from utils.coco_utils import get_coco_api_from_dataset, get_iou_types - def get_model_instance_segmentation(num_classes): # load an instance segmentation model pre-trained on COCO @@ -33,11 +31,10 @@ def get_model_instance_segmentation(num_classes): class MRCNNModule(pl.LightningModule): - def __init__(self, hidden_layer_size, n_classes): + def __init__(self, n_classes): super().__init__() # Hyperparameters - self.hidden_layers_size = hidden_layer_size self.n_classes = n_classes # log hyperparameters @@ -46,10 +43,12 @@ class MRCNNModule(pl.LightningModule): # Network self.model = get_model_instance_segmentation(n_classes) - # pycoco evaluator - self.coco = None - self.iou_types = get_iou_types(self.model) - self.coco_evaluator = None + # onnx + self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True) + + def forward(self, imgs): + self.model.eval() + return self.model(imgs) def training_step(self, batch, batch_idx): # unpack batch @@ -67,20 +66,17 @@ class MRCNNModule(pl.LightningModule): return loss def on_validation_epoch_start(self): - if self.coco is None: - self.coco = get_coco_api_from_dataset(self.trainer.val_dataloaders[0].dataset) - - # init coco evaluator - self.coco_evaluator = CocoEvaluator(self.coco, self.iou_types) + self.metric = MeanAveragePrecision(iou_type="bbox") def validation_step(self, batch, batch_idx): # unpack batch images, targets = batch - # compute metrics using pycocotools - outputs = self.model(images) - res = {target["image_id"].item(): output for target, output in zip(targets, outputs)} - self.coco_evaluator.update(res) + preds = self.model(images) + for pred, target in zip(preds, targets): + pred["masks"] = pred["masks"].squeeze(1).bool() + target["masks"] = target["masks"].squeeze(1).bool() + self.metric.update(preds, targets) # compute validation loss self.model.train() @@ -93,48 +89,22 @@ class MRCNNModule(pl.LightningModule): def validation_epoch_end(self, outputs): # log validation loss - loss_dict = {k: torch.stack([d[k] for d in outputs]).mean() for k in outputs[0].keys()} + loss_dict = { + k: torch.stack([d[k] for d in outputs]).mean() for k in outputs[0].keys() + } # TODO: update un dict object self.log_dict(loss_dict) - # accumulate all predictions - self.coco_evaluator.accumulate() - self.coco_evaluator.summarize() - - YEET = { - "valid,bbox,AP,IoU=0.50:0.,area=all,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[0], - "valid,bbox,AP,IoU=0.50,area=all,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[1], - "valid,bbox,AP,IoU=0.75,area=all,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[2], - "valid,bbox,AP,IoU=0.50:0.,area=small,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[3], - "valid,bbox,AP,IoU=0.50:0.,area=medium,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[4], - "valid,bbox,AP,IoU=0.50:0.,area=large,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[5], - "valid,bbox,AR,IoU=0.50:0.,area=all,maxDets=1": self.coco_evaluator.coco_eval["bbox"].stats[6], - "valid,bbox,AR,IoU=0.50:0.,area=all,maxDets=10": self.coco_evaluator.coco_eval["bbox"].stats[7], - "valid,bbox,AR,IoU=0.50:0.,area=all,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[8], - "valid,bbox,AR,IoU=0.50:0.,area=small,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[9], - "valid,bbox,AR,IoU=0.50:0.,area=medium,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[10], - "valid,bbox,AR,IoU=0.50:0.,area=large,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[11], - "valid,segm,AP,IoU=0.50:0.,area=all,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[0], - "valid,segm,AP,IoU=0.50,area=all,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[1], - "valid,segm,AP,IoU=0.75,area=all,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[2], - "valid,segm,AP,IoU=0.50:0.,area=small,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[3], - "valid,segm,AP,IoU=0.50:0.,area=medium,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[4], - "valid,segm,AP,IoU=0.50:0.,area=large,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[5], - "valid,segm,AR,IoU=0.50:0.,area=all,maxDets=1": self.coco_evaluator.coco_eval["segm"].stats[6], - "valid,segm,AR,IoU=0.50:0.,area=all,maxDets=10": self.coco_evaluator.coco_eval["segm"].stats[7], - "valid,segm,AR,IoU=0.50:0.,area=all,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[8], - "valid,segm,AR,IoU=0.50:0.,area=small,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[9], - "valid,segm,AR,IoU=0.50:0.,area=medium,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[10], - "valid,segm,AR,IoU=0.50:0.,area=large,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[11], - } - - self.log_dict(YEET) + # log metrics + metric_dict = self.metric.compute() + metric_dict = {f"valid/{key}": val for key, val in metric_dict.items()} + self.log_dict(metric_dict) def configure_optimizers(self): - optimizer = torch.optim.SGD( + optimizer = torch.optim.Adam( self.parameters(), lr=wandb.config.LEARNING_RATE, - momentum=wandb.config.MOMENTUM, - weight_decay=wandb.config.WEIGHT_DECAY, + # momentum=wandb.config.MOMENTUM, + # weight_decay=wandb.config.WEIGHT_DECAY, ) # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( diff --git a/src/notebooks/convert.ipynb.REMOVED.git-id b/src/notebooks/convert.ipynb.REMOVED.git-id index fa5f3c7..c923235 100644 --- a/src/notebooks/convert.ipynb.REMOVED.git-id +++ b/src/notebooks/convert.ipynb.REMOVED.git-id @@ -1 +1 @@ -dbe91cbc0788d93595dac272825aa57412f04d70 \ No newline at end of file +1163aa4f284163ecb7d4289c4c5217ea28f63770 \ No newline at end of file diff --git a/src/notebooks/module.py b/src/notebooks/module.py deleted file mode 100644 index 71c2e5f..0000000 --- a/src/notebooks/module.py +++ /dev/null @@ -1,152 +0,0 @@ -"""Pytorch lightning wrapper for model.""" - -import pytorch_lightning as pl -import torch -import torchvision -from torchvision.models.detection.faster_rcnn import FastRCNNPredictor -from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor - -import wandb - - -def get_model_instance_segmentation(num_classes): - # load an instance segmentation model pre-trained on COCO - model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True) - - # get number of input features for the classifier - in_features = model.roi_heads.box_predictor.cls_score.in_features - # replace the pre-trained head with a new one - model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) - - # now get the number of input features for the mask classifier - in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels - hidden_layer = 256 - # and replace the mask predictor with a new one - model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes) - - return model - - -class MRCNNModule(pl.LightningModule): - def __init__(self, hidden_layer_size, n_classes): - super().__init__() - - # Hyperparameters - self.hidden_layers_size = hidden_layer_size - self.n_classes = n_classes - - # log hyperparameters - self.save_hyperparameters() - - # Network - self.model = get_model_instance_segmentation(n_classes) - - # onnx - self.example_input_array = torch.randn(1, 3, 512, 512, requires_grad=True) - - def forward(self, imgs): - self.model.eval() - return self.model(imgs) - - def training_step(self, batch, batch_idx): - # unpack batch - images, targets = batch - - # enable train mode - # self.model.train() - - # fasterrcnn takes both images and targets for training - loss_dict = self.model(images, targets) - loss_dict = {f"train/{key}": val for key, val in loss_dict.items()} - loss = sum(loss_dict.values()) - - # log everything - self.log_dict(loss_dict) - self.log("train/loss", loss) - - return {"loss": loss, "log": loss_dict} - - # def validation_step(self, batch, batch_idx): - # # unpack batch - # images, targets = batch - - # # enable eval mode - # # self.detector.eval() - - # # make a prediction - # preds = self.model(images) - - # # compute validation loss - # self.val_loss = torch.mean( - # torch.stack( - # [ - # self.accuracy( - # target, - # pred["boxes"], - # iou_threshold=0.5, - # ) - # for target, pred in zip(targets, preds) - # ], - # ) - # ) - - # return self.val_loss - - # def accuracy(self, src_boxes, pred_boxes, iou_threshold=1.0): - # """ - # The accuracy method is not the one used in the evaluator but very similar - # """ - # total_gt = len(src_boxes) - # total_pred = len(pred_boxes) - # if total_gt > 0 and total_pred > 0: - - # # Define the matcher and distance matrix based on iou - # matcher = Matcher(iou_threshold, iou_threshold, allow_low_quality_matches=False) - # match_quality_matrix = box_iou(src_boxes, pred_boxes) - - # results = matcher(match_quality_matrix) - - # true_positive = torch.count_nonzero(results.unique() != -1) - # matched_elements = results[results > -1] - - # # in Matcher, a pred element can be matched only twice - # false_positive = torch.count_nonzero(results == -1) + ( - # len(matched_elements) - len(matched_elements.unique()) - # ) - # false_negative = total_gt - true_positive - - # return true_positive / (true_positive + false_positive + false_negative) - - # elif total_gt == 0: - # if total_pred > 0: - # return torch.tensor(0.0).cuda() - # else: - # return torch.tensor(1.0).cuda() - # elif total_gt > 0 and total_pred == 0: - # return torch.tensor(0.0).cuda() - - def configure_optimizers(self): - optimizer = torch.optim.SGD( - self.parameters(), - lr=wandb.config.LEARNING_RATE, - momentum=wandb.config.MOMENTUM, - weight_decay=wandb.config.WEIGHT_DECAY, - ) - - # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( - # optimizer, - # T_0=3, - # T_mult=1, - # lr=wandb.config.LEARNING_RATE_MIN, - # verbose=True, - # ) - - # return { - # "optimizer": optimizer, - # "lr_scheduler": { - # "scheduler": scheduler, - # "monitor": "val_accuracy", - # }, - # } - - return optimizer diff --git a/src/notebooks/predict.ipynb.REMOVED.git-id b/src/notebooks/predict.ipynb.REMOVED.git-id index e2538a4..42b7af2 100644 --- a/src/notebooks/predict.ipynb.REMOVED.git-id +++ b/src/notebooks/predict.ipynb.REMOVED.git-id @@ -1 +1 @@ -df787682f8ee4387834b713d1d48437b94d45f61 \ No newline at end of file +ecf0b9ce39e210bc605fd3eab9db8b1215c35fda \ No newline at end of file diff --git a/src/notebooks/test.py b/src/notebooks/test.py new file mode 100644 index 0000000..0121ce8 --- /dev/null +++ b/src/notebooks/test.py @@ -0,0 +1,70 @@ +import torch +from torchmetrics.detection.mean_ap import MeanAveragePrecision + +preds = [ + dict( + boxes=torch.tensor( + [ + [880.0560, 41.7845, 966.9839, 131.3355], + [1421.0029, 682.4420, 1512.7570, 765.2380], + [132.0775, 818.5026, 216.0825, 1020.8573], + ] + ), + scores=torch.tensor( + [0.9989, 0.9936, 0.0932], + ), + labels=torch.tensor( + [1, 1, 1], + ), + ) +] +target = [ + dict( + boxes=torch.tensor( + [[879, 39, 1513, 766]], + ), + labels=torch.tensor( + [1], + ), + ) +] +metric = MeanAveragePrecision() +metric.update(preds, target) + +from pprint import pprint + +pprint(metric.compute()) + +# -------------------------------------------------------------------------- + +preds = [ + dict( + boxes=torch.tensor( + [ + [880.0560, 41.7845, 1500, 700.3355], + ] + ), + scores=torch.tensor( + [0.9989], + ), + labels=torch.tensor( + [1], + ), + ) +] +target = [ + dict( + boxes=torch.tensor( + [[879, 39, 1513, 766]], + ), + labels=torch.tensor( + [1], + ), + ) +] +metric = MeanAveragePrecision() +metric.update(preds, target) + +from pprint import pprint + +pprint(metric.compute()) diff --git a/src/train.py b/src/train.py index 3556665..d0ec56e 100644 --- a/src/train.py +++ b/src/train.py @@ -1,43 +1,52 @@ import logging import pytorch_lightning as pl -from pytorch_lightning.callbacks import RichProgressBar +import torch +import wandb +from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar from pytorch_lightning.loggers import WandbLogger -import wandb from data import Spheres from mrcnn import MRCNNModule -from utils import ArtifactLog, TableLog if __name__ == "__main__": # setup logging - logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + logging.basicConfig( + level=logging.INFO, + format="%(levelname)s: %(message)s", + ) # setup wandb logger = WandbLogger( project="Mask R-CNN", config="wandb.yaml", + save_dir="/tmp/", + log_model="all", settings=wandb.Settings( code_dir="./src/", ), ) # seed random generators - pl.seed_everything(wandb.config.SEED, workers=True) + pl.seed_everything( + seed=wandb.config.SEED, + workers=True, + ) # Create Network module = MRCNNModule( - hidden_layer_size=-1, n_classes=2, ) # load checkpoint - # state_dict = torch.load("checkpoints/synth.pth") - # state_dict = dict([(f"model.{key}", value) for key, value in state_dict.items()]) - # model.load_state_dict(state_dict) + # module.load_state_dict(torch.load()["state_dict"]) + # module.load_from_checkpoint("/tmp/model.ckpt") # log gradients and weights regularly - logger.watch(module.model, log="all") + logger.watch( + model=module.model, + log="all", + ) # Create the dataloaders datamodule = Spheres() @@ -52,15 +61,17 @@ if __name__ == "__main__": logger=logger, log_every_n_steps=5, val_check_interval=50, - callbacks=[RichProgressBar(), ArtifactLog()], - # callbacks=[RichProgressBar(), ArtifactLog(), TableLog()], + callbacks=[ + ModelCheckpoint(monitor="valid/loss", mode="min"), + RichProgressBar(), + ], # profiler="advanced", num_sanity_val_steps=3, - devices=[0], + devices=[1], ) # actually train the model trainer.fit(model=module, datamodule=datamodule) # stop wandb - wandb.run.finish() + wandb.run.finish() # type: ignore diff --git a/src/utils/__init__.py b/src/utils/__init__.py index 2d4521b..969f4f0 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1,2 +1,2 @@ -from .callback import ArtifactLog, TableLog +from .callback import TableLog from .paste import RandomPaste diff --git a/src/utils/callback.py b/src/utils/callback.py index 971cace..7908a73 100644 --- a/src/utils/callback.py +++ b/src/utils/callback.py @@ -1,8 +1,5 @@ -import numpy as np -import torch -from pytorch_lightning.callbacks import Callback - import wandb +from pytorch_lightning.callbacks import Callback columns = [ "ID", @@ -64,34 +61,3 @@ class TableLog(Callback): ) } ) - - -class ArtifactLog(Callback): - # def on_fit_start(self, trainer, pl_module): - # self.best = 1 - - def on_train_epoch_end(self, trainer, pl_module): - # create checkpoint - torch.save(pl_module.state_dict(), "checkpoints/model.pth") - - # def on_validation_epoch_start(self, trainer, pl_module): - # self.dices = [] - - # def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - # # unpacking - # metrics, _ = outputs - # self.dices.append(metrics["dice"].cpu()) - - # def on_validation_epoch_end(self, trainer, pl_module): - # dice = np.mean(self.dices) - - # if dice < self.best: - # self.best = dice - - # # create checkpoint - # trainer.save_checkpoint("checkpoints/model.ckpt") - - # # log artifact - # artifact = wandb.Artifact("ckpt", type="model") - # artifact.add_file("checkpoints/model.ckpt") - # wandb.run.log_artifact(artifact) diff --git a/wandb.yaml b/wandb.yaml index dada248..a6677cd 100644 --- a/wandb.yaml +++ b/wandb.yaml @@ -1,9 +1,9 @@ -DIR_TRAIN_IMG: - value: "/media/disk1/lfainsin/BACKGROUND/" -DIR_VALID_IMG: - value: "/media/disk1/lfainsin/TEST_batched/" -DIR_SPHERE: - value: "/media/disk1/lfainsin/SPHERES/" +# DIR_TRAIN_IMG: +# value: "/media/disk1/lfainsin/BACKGROUND/" +# DIR_VALID_IMG: +# value: "/media/disk1/lfainsin/TEST_batched/" +# DIR_SPHERE: +# value: "/media/disk1/lfainsin/SPHERES/" N_CHANNELS: value: 3 @@ -28,7 +28,7 @@ WORKERS: value: 16 EPOCHS: - value: 10 + value: 50 TRAIN_BATCH_SIZE: value: 10 VALID_BATCH_SIZE: @@ -37,7 +37,7 @@ PREFETCH_FACTOR: value: 2 LEARNING_RATE: - value: 0.005 + value: 0.0005 WEIGHT_DECAY: value: 0.0005 MOMENTUM: