From 4015dad4910c61f02bd6c8ed99ae9031932baaaa Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Mon, 11 Jul 2022 14:41:18 +0200 Subject: [PATCH] refactor: moved the table logging in a callback Former-commit-id: 37fa7b0da4556417ee1665d9f0375e71ce075958 [formerly 9314a3dfee09b085bb0d125d178c9f589532c6f5] Former-commit-id: 9c8bef3f5c6560f26512bc6586f1337b9d7985f3 --- src/data/dataloader.py | 13 ++++++ src/train.py | 22 ++++++---- src/unet/module.py | 93 ++++-------------------------------------- src/utils/__init__.py | 1 + src/utils/callback.py | 64 +++++++++++++++++++++++++++++ 5 files changed, 102 insertions(+), 91 deletions(-) create mode 100644 src/utils/callback.py diff --git a/src/data/dataloader.py b/src/data/dataloader.py index 14479ed..2257262 100644 --- a/src/data/dataloader.py +++ b/src/data/dataloader.py @@ -36,6 +36,19 @@ class Spheres(pl.LightningDataModule): pin_memory=wandb.config.PIN_MEMORY, ) + # dataset = LabeledDataset(image_dir="/home/lilian/data_disk/lfainsin/prerender/") + # dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG) + # dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1))) + + # return DataLoader( + # dataset, + # shuffle=True, + # batch_size=8, + # prefetch_factor=8, + # num_workers=wandb.config.WORKERS, + # pin_memory=wandb.config.PIN_MEMORY, + # ) + def val_dataloader(self): dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG) dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1))) diff --git a/src/train.py b/src/train.py index d278141..6652d8a 100644 --- a/src/train.py +++ b/src/train.py @@ -1,17 +1,19 @@ import logging import pytorch_lightning as pl +import torch from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar from pytorch_lightning.loggers import WandbLogger import wandb from data import Spheres from unet import UNetModule +from utils import TableLog CONFIG = { "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/", - "DIR_VALID_IMG": "//home/lilian/data_disk/lfainsin/test_split/", - "DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres+real_split/", + "DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/test_batched_fast/", + "DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres+real/", "FEATURES": [8, 16, 32, 64], "N_CHANNELS": 3, "N_CLASSES": 1, @@ -19,9 +21,9 @@ CONFIG = { "PIN_MEMORY": True, "BENCHMARK": True, "DEVICE": "gpu", - "WORKERS": 14, - "EPOCHS": 1, - "BATCH_SIZE": 16 * 3, + "WORKERS": 8, + "EPOCHS": 10, + "BATCH_SIZE": 16, "LEARNING_RATE": 1e-4, "WEIGHT_DECAY": 1e-8, "MOMENTUM": 0.9, @@ -54,12 +56,18 @@ if __name__ == "__main__": features=CONFIG["FEATURES"], ) + # load checkpoint + state_dict = torch.load("checkpoints/synth.pth") + state_dict = dict([(f"model.{key}", value) for key, value in state_dict.items()]) + model.load_state_dict(state_dict) + # log gradients and weights regularly logger.watch(model, log="all") # create checkpoint callback checkpoint_callback = ModelCheckpoint( dirpath="checkpoints", + filename="model.ckpt", monitor="val/dice", ) @@ -75,8 +83,8 @@ if __name__ == "__main__": # precision=16, logger=logger, log_every_n_steps=1, - val_check_interval=100, - callbacks=RichProgressBar(), + val_check_interval=25, + callbacks=[RichProgressBar(), checkpoint_callback, TableLog()], ) trainer.fit(model=model, datamodule=datamodule) diff --git a/src/unet/module.py b/src/unet/module.py index 7c07962..86bb107 100644 --- a/src/unet/module.py +++ b/src/unet/module.py @@ -52,105 +52,30 @@ class UNetModule(pl.LightningModule): } # wrap tensors in dictionnary - tensors = { - "data": data, - "ground_truth": ground_truth, - "prediction": prediction, + predictions = { + "linear": prediction, "binary": binary, } - return metrics, tensors + return metrics, predictions def training_step(self, batch, batch_idx): # compute metrics - metrics, tensors = self.shared_step(batch) + metrics, _ = self.shared_step(batch) # log metrics self.log_dict(dict([(f"train/{key}", value) for key, value in metrics.items()])) - if batch_idx == 5000: - rows = [] - columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"] - for i, (img, mask, pred, pred_bin) in enumerate( - zip( # TODO: use comprehension list to zip the dictionnary - tensors["data"].cpu(), - tensors["ground_truth"].cpu(), - tensors["prediction"].cpu(), - tensors["binary"] - .cpu() - .squeeze(1) - .int() - .numpy(), # TODO: check if .functions can be moved elsewhere - ) - ): - rows.append( - [ - i, - wandb.Image(img), - wandb.Image(mask), - wandb.Image( - pred, - masks={ - "predictions": { - "mask_data": pred_bin, - "class_labels": class_labels, - }, - }, - ), - metrics["dice"], - metrics["dice_bin"], - ] - ) - - # log table - wandb.log( - { - "train/predictions": wandb.Table( - columns=columns, - data=rows, - ) - } - ) - return metrics["dice"] def validation_step(self, batch, batch_idx): - metrics, tensors = self.shared_step(batch) + # compute metrics + metrics, predictions = self.shared_step(batch) - rows = [] - if batch_idx % 50 == 0 or metrics["dice"] > 0.9: - for i, (img, mask, pred, pred_bin) in enumerate( - zip( # TODO: use comprehension list to zip the dictionnary - tensors["data"].cpu(), - tensors["ground_truth"].cpu(), - tensors["prediction"].cpu(), - tensors["binary"] - .cpu() - .squeeze(1) - .int() - .numpy(), # TODO: check if .functions can be moved elsewhere - ) - ): - rows.append( - [ - i, - wandb.Image(img), - wandb.Image(mask), - wandb.Image( - pred, - masks={ - "predictions": { - "mask_data": pred_bin, - "class_labels": class_labels, - }, - }, - ), - metrics["dice"], - metrics["dice_bin"], - ] - ) + # log metrics + self.log_dict(dict([(f"val/{key}", value) for key, value in metrics.items()])) - return metrics, rows + return metrics, predictions def validation_epoch_end(self, validation_outputs): # unpacking diff --git a/src/utils/__init__.py b/src/utils/__init__.py index 546b2fe..969f4f0 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1 +1,2 @@ +from .callback import TableLog from .paste import RandomPaste diff --git a/src/utils/callback.py b/src/utils/callback.py new file mode 100644 index 0000000..8a4879d --- /dev/null +++ b/src/utils/callback.py @@ -0,0 +1,64 @@ +from pytorch_lightning.callbacks import Callback +from torch import tensor + +import wandb + +columns = [ + "ID", + "image", + "ground truth", + "prediction", + "dice", + "dice_bin", +] +class_labels = { + 1: "sphere", +} + + +class TableLog(Callback): + def on_validation_epoch_start(self, trainer, pl_module): + self.rows = [] + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + # unpacking + images, ground_truth = batch + metrics, predictions = outputs + + for i, (img, mask, pred, pred_bin) in enumerate( + zip( + images.cpu(), + ground_truth.cpu(), + predictions["linear"].cpu(), + predictions["binary"].cpu().squeeze(1).int().numpy(), + ) + ): + self.rows.append( + [ + i, + wandb.Image(img), + wandb.Image(mask), + wandb.Image( + pred, + masks={ + "predictions": { + "mask_data": pred_bin, + "class_labels": class_labels, + }, + }, + ), + metrics["dice"], + metrics["dice_bin"], + ] + ) + + def on_validation_epoch_end(self, trainer, pl_module): + # log table + wandb.log( + { + "val/predictions": wandb.Table( + columns=columns, + data=self.rows, + ) + } + )