From 47c888cf6c9c4aba4106b21f53e8185b92c2eaf1 Mon Sep 17 00:00:00 2001
From: Laurent Fainsin
Date: Mon, 11 Jul 2022 15:34:05 +0200
Subject: [PATCH] feat: checkpoint wandb logging

feat: wandb config file

Former-commit-id: 45f56db86ca269b028cf76bf5315bc0eef8d2a21 [formerly e320b72e16eed02bdca05245e7c77914f0e288f9]
Former-commit-id: 6d91318784748308c73dc6a164653f04ae46cd2a
---
 .editorconfig            |  2 +-
 .vscode/launch.json      |  2 +-
 src/config-defaults.yaml | 41 +++++++++++++++++++
 src/data/dataloader.py   | 15 +------
 src/train.py             | 63 ++++++++---------------------
 src/utils/__init__.py    |  2 +-
 src/utils/callback.py    | 85 +++++++++++++++++++++++++++-------------
 7 files changed, 119 insertions(+), 91 deletions(-)
 create mode 100644 src/config-defaults.yaml

diff --git a/.editorconfig b/.editorconfig
index c8cd2d4..8c28df6 100644
--- a/.editorconfig
+++ b/.editorconfig
@@ -11,5 +11,5 @@ charset = utf-8
 trim_trailing_whitespace = true
 insert_final_newline = true
 
-[*.{json,toml}]
+[*.{json,toml,yaml,yml}]
 indent_size = 2

diff --git a/.vscode/launch.json b/.vscode/launch.json
index a0ae3f2..fdb26d6 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -18,7 +18,7 @@
         "--model",
         "good.onnx",
       ],
-      "justMyCode": true
+      "justMyCode": false
     }
   ]
 }

diff --git a/src/config-defaults.yaml b/src/config-defaults.yaml
new file mode 100644
index 0000000..4fcda67
--- /dev/null
+++ b/src/config-defaults.yaml
@@ -0,0 +1,41 @@
+DIR_TRAIN_IMG:
+  value: "/home/lilian/data_disk/lfainsin/train/"
+DIR_VALID_IMG:
+  value: "/home/lilian/data_disk/lfainsin/test_batched_fast/"
+DIR_SPHERE:
+  value: "/home/lilian/data_disk/lfainsin/spheres+real/"
+
+FEATURES:
+  value: [8, 16, 32, 64]
+N_CHANNELS:
+  value: 3
+N_CLASSES:
+  value: 1
+
+AMP:
+  value: True
+PIN_MEMORY:
+  value: True
+BENCHMARK:
+  value: True
+DEVICE:
+  value: gpu
+WORKERS:
+  value: 8
+
+IMG_SIZE:
+  value: 512
+SPHERES:
+  value: 5
+
+EPOCHS:
+  value: 10
+BATCH_SIZE:
+  value: 16
+
+LEARNING_RATE:
+  value: 1.0e-4
+WEIGHT_DECAY:
+  value: 1.0e-8
+MOMENTUM:
+  value: 0.9

diff --git a/src/data/dataloader.py b/src/data/dataloader.py
index 2257262..fe26b0b 100644
--- a/src/data/dataloader.py
+++ b/src/data/dataloader.py
@@ -36,19 +36,6 @@ class Spheres(pl.LightningDataModule):
             pin_memory=wandb.config.PIN_MEMORY,
         )
 
-        # dataset = LabeledDataset(image_dir="/home/lilian/data_disk/lfainsin/prerender/")
-        # dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
-        # dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
-
-        # return DataLoader(
-        #     dataset,
-        #     shuffle=True,
-        #     batch_size=8,
-        #     prefetch_factor=8,
-        #     num_workers=wandb.config.WORKERS,
-        #     pin_memory=wandb.config.PIN_MEMORY,
-        # )
-
     def val_dataloader(self):
         dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
         dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
@@ -56,7 +43,7 @@ class Spheres(pl.LightningDataModule):
         return DataLoader(
             dataset,
             shuffle=False,
-            batch_size=1,
+            batch_size=8,
             prefetch_factor=8,
             num_workers=wandb.config.WORKERS,
             pin_memory=wandb.config.PIN_MEMORY,

diff --git a/src/train.py b/src/train.py
index 6652d8a..1a7dbe5 100644
--- a/src/train.py
+++ b/src/train.py
@@ -1,44 +1,21 @@
 import logging
 
 import pytorch_lightning as pl
-import torch
-from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
+from pytorch_lightning.callbacks import RichProgressBar
 from pytorch_lightning.loggers import WandbLogger
 
 import wandb
 from data import Spheres
 from unet import UNetModule
-from utils import TableLog
-
-CONFIG = {
"DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/", - "DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/test_batched_fast/", - "DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres+real/", - "FEATURES": [8, 16, 32, 64], - "N_CHANNELS": 3, - "N_CLASSES": 1, - "AMP": True, - "PIN_MEMORY": True, - "BENCHMARK": True, - "DEVICE": "gpu", - "WORKERS": 8, - "EPOCHS": 10, - "BATCH_SIZE": 16, - "LEARNING_RATE": 1e-4, - "WEIGHT_DECAY": 1e-8, - "MOMENTUM": 0.9, - "IMG_SIZE": 512, - "SPHERES": 5, -} +from utils import ArtifactLog, TableLog if __name__ == "__main__": # setup logging logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") - # setup wandb + # setup wandb, config loaded from config-default.yaml logger = WandbLogger( project="U-Net", - config=CONFIG, settings=wandb.Settings( code_dir="./src/", ), @@ -49,44 +26,38 @@ if __name__ == "__main__": # Create network model = UNetModule( - n_channels=CONFIG["N_CHANNELS"], - n_classes=CONFIG["N_CLASSES"], - batch_size=CONFIG["BATCH_SIZE"], - learning_rate=CONFIG["LEARNING_RATE"], - features=CONFIG["FEATURES"], + n_channels=wandb.config.N_CHANNELS, + n_classes=wandb.config.N_CLASSES, + batch_size=wandb.config.BATCH_SIZE, + learning_rate=wandb.config.LEARNING_RATE, + features=wandb.config.FEATURES, ) # load checkpoint - state_dict = torch.load("checkpoints/synth.pth") - state_dict = dict([(f"model.{key}", value) for key, value in state_dict.items()]) - model.load_state_dict(state_dict) + # state_dict = torch.load("checkpoints/synth.pth") + # state_dict = dict([(f"model.{key}", value) for key, value in state_dict.items()]) + # model.load_state_dict(state_dict) # log gradients and weights regularly logger.watch(model, log="all") - # create checkpoint callback - checkpoint_callback = ModelCheckpoint( - dirpath="checkpoints", - filename="model.ckpt", - monitor="val/dice", - ) - # Create the dataloaders datamodule = Spheres() # Create the trainer trainer = pl.Trainer( - max_epochs=CONFIG["EPOCHS"], - accelerator=CONFIG["DEVICE"], - benchmark=CONFIG["BENCHMARK"], + max_epochs=wandb.config.EPOCHS, + accelerator=wandb.config.DEVICE, + benchmark=wandb.config.BENCHMARK, # profiler="simple", # precision=16, logger=logger, log_every_n_steps=1, - val_check_interval=25, - callbacks=[RichProgressBar(), checkpoint_callback, TableLog()], + val_check_interval=100, + callbacks=[RichProgressBar(), ArtifactLog(), TableLog()], ) + # actually train the model trainer.fit(model=model, datamodule=datamodule) # stop wandb diff --git a/src/utils/__init__.py b/src/utils/__init__.py index 969f4f0..2d4521b 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1,2 +1,2 @@ -from .callback import TableLog +from .callback import ArtifactLog, TableLog from .paste import RandomPaste diff --git a/src/utils/callback.py b/src/utils/callback.py index 8a4879d..20497f2 100644 --- a/src/utils/callback.py +++ b/src/utils/callback.py @@ -1,5 +1,6 @@ +import numpy as np +import torch from pytorch_lightning.callbacks import Callback -from torch import tensor import wandb @@ -22,35 +23,36 @@ class TableLog(Callback): def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): # unpacking - images, ground_truth = batch - metrics, predictions = outputs + if batch_idx == 0: + images, ground_truth = batch + metrics, predictions = outputs - for i, (img, mask, pred, pred_bin) in enumerate( - zip( - images.cpu(), - ground_truth.cpu(), - predictions["linear"].cpu(), - predictions["binary"].cpu().squeeze(1).int().numpy(), - ) - ): 
-            self.rows.append(
-                [
-                    i,
-                    wandb.Image(img),
-                    wandb.Image(mask),
-                    wandb.Image(
-                        pred,
-                        masks={
-                            "predictions": {
-                                "mask_data": pred_bin,
-                                "class_labels": class_labels,
+            for i, (img, mask, pred, pred_bin) in enumerate(
+                zip(
+                    images.cpu(),
+                    ground_truth.cpu(),
+                    predictions["linear"].cpu(),
+                    predictions["binary"].cpu().squeeze(1).int().numpy(),
+                )
+            ):
+                self.rows.append(
+                    [
+                        i,
+                        wandb.Image(img),
+                        wandb.Image(mask),
+                        wandb.Image(
+                            pred,
+                            masks={
+                                "predictions": {
+                                    "mask_data": pred_bin,
+                                    "class_labels": class_labels,
+                                },
                             },
-                        },
-                    ),
-                    metrics["dice"],
-                    metrics["dice_bin"],
-                ]
-            )
+                        ),
+                        metrics["dice"],
+                        metrics["dice_bin"],
+                    ]
+                )
 
     def on_validation_epoch_end(self, trainer, pl_module):
         # log table
         wandb.log(
             {
                 "val/predictions": wandb.Table(
                     columns=columns,
                     data=self.rows,
                 )
             }
         )
+
+
+class ArtifactLog(Callback):
+    def on_fit_start(self, trainer, pl_module):
+        # track the best (lowest) mean validation dice across the whole fit,
+        # not per validation epoch, so only improved checkpoints get logged
+        self.best = 1
+
+    def on_validation_epoch_start(self, trainer, pl_module):
+        self.dices = []
+
+    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+        # unpacking
+        metrics, _ = outputs
+        self.dices.append(metrics["dice"].cpu())
+
+    def on_validation_epoch_end(self, trainer, pl_module):
+        dice = np.mean(self.dices)
+        self.dices = []
+
+        if dice < self.best:
+            self.best = dice
+
+            # create checkpoint from the model weights, not the callback state
+            torch.save(pl_module.state_dict(), "checkpoints/model.pth")
+            # trainer.save_checkpoint("example.ckpt")  # TODO: change to .ckpt
+
+            # create and log artifact
+            artifact = wandb.Artifact("pth", type="model")
+            artifact.add_file("checkpoints/model.pth")
+            wandb.run.log_artifact(artifact)
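A note on the mechanism this patch switches to: wandb automatically reads a file named
config-defaults.yaml from the directory the script is launched from when a run is
initialized (WandbLogger calls wandb.init() under the hood), and each top-level key's
"value" entry is merged into wandb.config. A minimal sketch of that behavior, assuming
the script runs from src/ so the file is found:

    import wandb

    # wandb.init() picks up config-defaults.yaml from the current working
    # directory; no explicit config= argument is needed anymore
    run = wandb.init(project="U-Net")

    print(run.config.IMG_SIZE)  # 512
    print(run.config.FEATURES)  # [8, 16, 32, 64]

This is why the config=CONFIG argument to WandbLogger could be dropped in train.py.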
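Similarly, a sketch of how a downstream script might retrieve the weights that
ArtifactLog uploads. The "pth" artifact name and "U-Net" project come from the patch;
the hyperparameter values below are just the defaults from config-defaults.yaml and
are illustrative only:

    import torch
    import wandb
    from unet import UNetModule

    run = wandb.init(project="U-Net")

    # fetch the latest version of the logged model artifact
    artifact = run.use_artifact("pth:latest", type="model")
    directory = artifact.download()

    # rebuild the network with the hyperparameters it was trained with
    model = UNetModule(
        n_channels=3,
        n_classes=1,
        batch_size=16,
        learning_rate=1e-4,
        features=[8, 16, 32, 64],
    )
    model.load_state_dict(torch.load(f"{directory}/model.pth"))
    model.eval()

The state dict loads cleanly because ArtifactLog saves pl_module.state_dict(), i.e.
the UNetModule's own keys rather than the callback's.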