diff --git a/.gitignore b/.gitignore index cdc3479..40eb468 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ __pycache__/ wandb/ images/ +lightning_logs/ checkpoints/ *.pth diff --git a/poetry.lock b/poetry.lock index ec4dac1..08ed37b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -189,6 +189,17 @@ category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +[[package]] +name = "commonmark" +version = "0.9.1" +description = "Python parser for the CommonMark Markdown spec" +category = "main" +optional = false +python-versions = "*" + +[package.extras] +test = ["flake8 (==3.7.8)", "hypothesis (==3.55.3)"] + [[package]] name = "cycler" version = "0.11.0" @@ -881,7 +892,7 @@ python-versions = ">=3.6" name = "pygments" version = "2.12.0" description = "Pygments is a syntax highlighting package written in Python." -category = "dev" +category = "main" optional = false python-versions = ">=3.6" @@ -1027,6 +1038,22 @@ requests = ">=2.0.0" [package.extras] rsa = ["oauthlib[signedtoken] (>=3.0.0)"] +[[package]] +name = "rich" +version = "12.4.4" +description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +category = "main" +optional = false +python-versions = ">=3.6.3,<4.0.0" + +[package.dependencies] +commonmark = ">=0.9.0,<0.10.0" +pygments = ">=2.6.0,<3.0.0" +typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9\""} + +[package.extras] +jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"] + [[package]] name = "rsa" version = "4.8" @@ -1446,7 +1473,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [metadata] lock-version = "1.1" python-versions = ">=3.8,<3.11" -content-hash = "b192d0e5f593e99630bb92cd31c510dcdea67b0b54861176f92f50505724e7d5" +content-hash = "416650c968a0021f7d64028f272464d96319c361a72888ae4cb3e2a602873832" [metadata.files] absl-py = [ @@ -1652,6 +1679,10 @@ colorama = [ {file = "colorama-0.4.5-py2.py3-none-any.whl", hash = "sha256:854bf444933e37f5824ae7bfc1e98d5bce2ebe4160d46b5edf346a89358e99da"}, {file = "colorama-0.4.5.tar.gz", hash = "sha256:e6c6b4334fc50988a639d9b98aa429a0b57da6e17b9a44f0451f930b6967b7a4"}, ] +commonmark = [ + {file = "commonmark-0.9.1-py2.py3-none-any.whl", hash = "sha256:da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9"}, + {file = "commonmark-0.9.1.tar.gz", hash = "sha256:452f9dc859be7f06631ddcb328b6919c67984aca654e5fefb3914d54691aed60"}, +] cycler = [ {file = "cycler-0.11.0-py3-none-any.whl", hash = "sha256:3a27e95f763a428a739d2add979fa7494c912a32c17c4c38c4d5f082cad165a3"}, {file = "cycler-0.11.0.tar.gz", hash = "sha256:9c87405839a19696e837b3b818fed3f5f69f16f1eec1a1ad77e043dcea9c772f"}, @@ -2419,6 +2450,10 @@ requests-oauthlib = [ {file = "requests-oauthlib-1.3.1.tar.gz", hash = "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a"}, {file = "requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5"}, ] +rich = [ + {file = "rich-12.4.4-py3-none-any.whl", hash = "sha256:d2bbd99c320a2532ac71ff6a3164867884357da3e3301f0240090c5d2fdac7ec"}, + {file = "rich-12.4.4.tar.gz", hash = "sha256:4c586de507202505346f3e32d1363eb9ed6932f0c2f63184dea88983ff4971e2"}, +] rsa = [ {file = "rsa-4.8-py3-none-any.whl", hash = "sha256:95c5d300c4e879ee69708c428ba566c59478fd653cc3a22243eeb8ed846950bb"}, {file = "rsa-4.8.tar.gz", hash = "sha256:5c6bd9dc7a543b7fe4304a631f8a8a3b674e2bbfc49c2ae96200cdbe55df6b17"}, diff --git a/pyproject.toml b/pyproject.toml index 426fef6..ef17a5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ torch = "^1.11.0" torchvision = "^0.12.0" tqdm = "^4.64.0" wandb = "^0.12.19" +rich = "^12.4.4" [tool.poetry.dev-dependencies] black = "^22.3.0" diff --git a/src/train.py b/src/train.py index e8f3594..2471320 100644 --- a/src/train.py +++ b/src/train.py @@ -3,8 +3,8 @@ import logging import albumentations as A import pytorch_lightning as pl import torch -import yaml from albumentations.pytorch import ToTensorV2 +from pytorch_lightning.callbacks import RichProgressBar from pytorch_lightning.loggers import WandbLogger from torch.utils.data import DataLoader @@ -13,8 +13,27 @@ from src.utils.dataset import SphereDataset from unet import UNet from utils.paste import RandomPaste -class_labels = { - 1: "sphere", +CONFIG = { + "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/", + "DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/val/", + "DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/", + "DIR_SPHERE_IMG": "/home/lilian/data_disk/lfainsin/spheres/Images/", + "DIR_SPHERE_MASK": "/home/lilian/data_disk/lfainsin/spheres/Masks/", + "FEATURES": [64, 128, 256, 512], + "N_CHANNELS": 3, + "N_CLASSES": 1, + "AMP": True, + "PIN_MEMORY": True, + "BENCHMARK": True, + "DEVICE": "gpu", + "WORKERS": 8, + "EPOCHS": 5, + "BATCH_SIZE": 16, + "LEARNING_RATE": 1e-4, + "WEIGHT_DECAY": 1e-8, + "MOMENTUM": 0.9, + "IMG_SIZE": 512, + "SPHERES": 5, } if __name__ == "__main__": @@ -24,28 +43,7 @@ if __name__ == "__main__": # setup wandb logger = WandbLogger( project="U-Net", - config=dict( - DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/train/", - DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/val/", - DIR_TEST_IMG="/home/lilian/data_disk/lfainsin/test/", - DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/", - DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/", - FEATURES=[64, 128, 256, 512], - N_CHANNELS=3, - N_CLASSES=1, - AMP=True, - PIN_MEMORY=True, - BENCHMARK=True, - DEVICE="gpu", - WORKERS=8, - EPOCHS=5, - BATCH_SIZE=16, - LEARNING_RATE=1e-4, - WEIGHT_DECAY=1e-8, - MOMENTUM=0.9, - IMG_SIZE=512, - SPHERES=5, - ), + config=CONFIG, settings=wandb.Settings( code_dir="./src/", ), @@ -55,10 +53,7 @@ if __name__ == "__main__": pl.seed_everything(69420, workers=True) # 0. Create network - net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES) - - # log the number of parameters of the model - wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad) + net = UNet(n_channels=CONFIG["N_CHANNELS"], n_classes=CONFIG["N_CLASSES"], features=CONFIG["FEATURES"]) # log gradients and weights regularly logger.watch(net, log="all") @@ -66,88 +61,59 @@ if __name__ == "__main__": # 1. Create transforms tf_train = A.Compose( [ - A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE), + A.Resize(CONFIG["IMG_SIZE"], CONFIG["IMG_SIZE"]), A.Flip(), A.ColorJitter(), - RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK), + RandomPaste(CONFIG["SPHERES"], CONFIG["DIR_SPHERE_IMG"], CONFIG["DIR_SPHERE_MASK"]), A.GaussianBlur(), A.ISONoise(), A.ToFloat(max_value=255), ToTensorV2(), ], ) - tf_valid = A.Compose( - [ - A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE), - RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK), - A.ToFloat(max_value=255), - ToTensorV2(), - ], - ) # 2. Create datasets - ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train) - ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid) - ds_test = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG) + ds_train = SphereDataset(image_dir=CONFIG["DIR_TRAIN_IMG"], transform=tf_train) + ds_valid = SphereDataset(image_dir=CONFIG["DIR_TEST_IMG"]) # 2.5. Create subset, if uncommented ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000))) - ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 1000))) - ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100))) + # ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100))) + # ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100))) # 3. Create data loaders train_loader = DataLoader( ds_train, shuffle=True, - batch_size=wandb.config.BATCH_SIZE, - num_workers=wandb.config.WORKERS, - pin_memory=wandb.config.PIN_MEMORY, + batch_size=CONFIG["BATCH_SIZE"], + num_workers=CONFIG["WORKERS"], + pin_memory=CONFIG["PIN_MEMORY"], ) val_loader = DataLoader( ds_valid, shuffle=False, drop_last=True, - batch_size=wandb.config.BATCH_SIZE, - num_workers=wandb.config.WORKERS, - pin_memory=wandb.config.PIN_MEMORY, - ) - test_loader = DataLoader( - ds_test, - shuffle=False, - drop_last=False, batch_size=1, - num_workers=wandb.config.WORKERS, - pin_memory=wandb.config.PIN_MEMORY, + num_workers=CONFIG["WORKERS"], + pin_memory=CONFIG["PIN_MEMORY"], ) # 4. Create the trainer trainer = pl.Trainer( - max_epochs=wandb.config.EPOCHS, - accelerator="gpu", - precision=16, + max_epochs=CONFIG["EPOCHS"], + accelerator=CONFIG["DEVICE"], + # precision=16, auto_scale_batch_size="binsearch", - benchmark=wandb.config.BENCHMARK, + benchmark=CONFIG["BENCHMARK"], val_check_interval=100, + callbacks=RichProgressBar(), ) - # print the config - logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}") - - # # wandb init log - # wandb.log( - # { - # "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"], - # }, - # commit=False, - # ) - try: trainer.fit( model=net, train_dataloaders=train_loader, val_dataloaders=val_loader, - test_dataloaders=test_loader, - accelerator=wandb.config.DEVICE, ) except KeyboardInterrupt: torch.save(net.state_dict(), "INTERRUPTED.pth") diff --git a/src/unet/model.py b/src/unet/model.py index 378b407..b9d6c18 100644 --- a/src/unet/model.py +++ b/src/unet/model.py @@ -1,7 +1,5 @@ """ Full assembly of the parts to form the complete network """ -from xmlrpc.server import list_public_methods - import numpy as np import pytorch_lightning as pl @@ -40,6 +38,7 @@ class UNet(pl.LightningModule): def forward(self, x): skips = [] + x = x.to(self.device) x = self.inc(x) for down in self.downs: @@ -53,8 +52,7 @@ class UNet(pl.LightningModule): return x - @staticmethod - def save_to_table(images, masks_true, masks_pred, masks_pred_bin, log_key): + def save_to_table(self, images, masks_true, masks_pred, masks_pred_bin, log_key): table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"]) for i, (img, mask, pred, pred_bin) in enumerate( @@ -99,16 +97,17 @@ class UNet(pl.LightningModule): accuracy = (masks_true == masks_pred_bin).float().mean() dice = dice_coeff(masks_pred_bin, masks_true) - wandb.log( + self.log( + "train", { - "train/accuracy": accuracy, - "train/bce": loss, - "train/dice": dice, - "train/mae": mae, - } + "accuracy": accuracy, + "bce": loss, + "dice": dice, + "mae": mae, + }, ) - return loss, dice, accuracy, mae + return loss # , dice, accuracy, mae def validation_step(self, batch, batch_idx): # unpacking @@ -119,79 +118,79 @@ class UNet(pl.LightningModule): # compute metrics loss = F.cross_entropy(masks_pred, masks_true) - mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true) - accuracy = (masks_true == masks_pred_bin).float().mean() - dice = dice_coeff(masks_pred_bin, masks_true) + # mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true) + # accuracy = (masks_true == masks_pred_bin).float().mean() + # dice = dice_coeff(masks_pred_bin, masks_true) if batch_idx == 0: self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "val/predictions") - return loss, dice, accuracy, mae + return loss # , dice, accuracy, mae - def validation_step_end(self, validation_outputs): - # unpacking - loss, dice, accuracy, mae = validation_outputs - optimizer = self.optimizers[0] - learning_rate = optimizer.state_dict()["param_groups"][0]["lr"] + # def validation_step_end(self, validation_outputs): + # # unpacking + # loss, dice, accuracy, mae = validation_outputs + # # optimizer = self.optimizers[0] + # # learning_rate = optimizer.state_dict()["param_groups"][0]["lr"] - wandb.log( - { - "train/learning_rate": learning_rate, - "val/accuracy": accuracy, - "val/bce": loss, - "val/dice": dice, - "val/mae": mae, - } - ) + # wandb.log( + # { + # # "train/learning_rate": learning_rate, + # "val/accuracy": accuracy, + # "val/bce": loss, + # "val/dice": dice, + # "val/mae": mae, + # } + # ) - # export model to onnx - dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True) - torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx") - artifact = wandb.Artifact("onnx", type="model") - artifact.add_file(f"checkpoints/model.onnx") - wandb.run.log_artifact(artifact) + # # export model to onnx + # dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True) + # torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx") + # artifact = wandb.Artifact("onnx", type="model") + # artifact.add_file(f"checkpoints/model.onnx") + # wandb.run.log_artifact(artifact) - def test_step(self, batch, batch_idx): - # unpacking - images, masks_true = batch - masks_true = masks_true.unsqueeze(1) - masks_pred = self(images) - masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() + # def test_step(self, batch, batch_idx): + # # unpacking + # images, masks_true = batch + # masks_true = masks_true.unsqueeze(1) + # masks_pred = self(images) + # masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() - # compute metrics - loss = F.cross_entropy(masks_pred, masks_true) - mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true) - accuracy = (masks_true == masks_pred_bin).float().mean() - dice = dice_coeff(masks_pred_bin, masks_true) + # # compute metrics + # loss = F.cross_entropy(masks_pred, masks_true) + # mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true) + # accuracy = (masks_true == masks_pred_bin).float().mean() + # dice = dice_coeff(masks_pred_bin, masks_true) - if batch_idx == 0: - self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "test/predictions") + # if batch_idx == 0: + # self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "test/predictions") - return loss, dice, accuracy, mae + # return loss, dice, accuracy, mae - def test_step_end(self, test_outputs): - # unpacking - list_loss, list_dice, list_accuracy, list_mae = test_outputs + # def test_step_end(self, test_outputs): + # # unpacking + # list_loss, list_dice, list_accuracy, list_mae = test_outputs - # averaging - loss = np.mean(list_loss) - dice = np.mean(list_dice) - accuracy = np.mean(list_accuracy) - mae = np.mean(list_mae) + # # averaging + # loss = np.mean(list_loss) + # dice = np.mean(list_dice) + # accuracy = np.mean(list_accuracy) + # mae = np.mean(list_mae) - # get learning rate - optimizer = self.optimizers[0] - learning_rate = optimizer.state_dict()["param_groups"][0]["lr"] + # # # get learning rate + # # optimizer = self.optimizers[0] + # # learning_rate = optimizer.state_dict()["param_groups"][0]["lr"] - wandb.log( - { - "train/learning_rate": learning_rate, - "val/accuracy": accuracy, - "val/bce": loss, - "val/dice": dice, - "val/mae": mae, - } - ) + # wandb.log( + # { + # # "train/learning_rate": learning_rate, + # "test/accuracy": accuracy, + # "test/bce": loss, + # "test/dice": dice, + # "test/mae": mae, + # } + # ) def configure_optimizers(self): optimizer = torch.optim.RMSprop( @@ -200,10 +199,10 @@ class UNet(pl.LightningModule): weight_decay=wandb.config.WEIGHT_DECAY, momentum=wandb.config.MOMENTUM, ) - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, - "max", - patience=2, - ) + # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + # optimizer, + # "max", + # patience=2, + # ) - return optimizer, scheduler + return optimizer # , scheduler