diff --git a/src/comp.ipynb.REMOVED.git-id b/src/comp.ipynb.REMOVED.git-id index 21ecac1..e1a4cf5 100644 --- a/src/comp.ipynb.REMOVED.git-id +++ b/src/comp.ipynb.REMOVED.git-id @@ -1 +1 @@ -0f3136c724eea42fdf1ee15e721ef33604e9a46d \ No newline at end of file +ac8ff07f541ae6d7cba729b20e0d04654c6018c9 \ No newline at end of file diff --git a/src/data/__init__.py b/src/data/__init__.py new file mode 100644 index 0000000..81370c5 --- /dev/null +++ b/src/data/__init__.py @@ -0,0 +1 @@ +from .dataloader import SyntheticSphere diff --git a/src/data/dataloader.py b/src/data/dataloader.py new file mode 100644 index 0000000..f575d3a --- /dev/null +++ b/src/data/dataloader.py @@ -0,0 +1,50 @@ +import albumentations as A +import pytorch_lightning as pl +from albumentations.pytorch import ToTensorV2 +from torch.utils.data import DataLoader + +import wandb +from utils import RandomPaste + +from .dataset import SphereDataset + + +class SyntheticSphere(pl.LightningDataModule): + def __init__(self): + super().__init__() + + def train_dataloader(self): + tf_train = A.Compose( + [ + A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE), + A.Flip(), + A.ColorJitter(), + RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE), + A.GaussianBlur(), + A.ISONoise(), + A.ToFloat(max_value=255), + ToTensorV2(), + ], + ) + + ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train) + # ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000))) + + return DataLoader( + ds_train, + shuffle=True, + batch_size=wandb.config.BATCH_SIZE, + num_workers=wandb.config.WORKERS, + pin_memory=wandb.config.PIN_MEMORY, + ) + + def val_dataloader(self): + ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG) + + return DataLoader( + ds_valid, + shuffle=False, + batch_size=1, + num_workers=wandb.config.WORKERS, + pin_memory=wandb.config.PIN_MEMORY, + ) diff --git a/src/utils/dataset.py b/src/data/dataset.py similarity index 100% rename from src/utils/dataset.py rename to src/data/dataset.py diff --git a/src/train.py b/src/train.py index 84ccb87..8e34144 100644 --- a/src/train.py +++ b/src/train.py @@ -1,7 +1,6 @@ import logging import pytorch_lightning as pl -import torch from pytorch_lightning.callbacks import RichProgressBar from pytorch_lightning.loggers import WandbLogger @@ -57,13 +56,17 @@ if __name__ == "__main__": # log gradients and weights regularly logger.watch(net, log="all") + # create checkpoint callback + checkpoint_callback = pl.ModelCheckpoint( + dirpath="checkpoints", + monitor="val/dice", + ) + # Create the trainer trainer = pl.Trainer( max_epochs=CONFIG["EPOCHS"], accelerator=CONFIG["DEVICE"], # precision=16, - # auto_scale_batch_size="binsearch", - # auto_lr_find=True, benchmark=CONFIG["BENCHMARK"], val_check_interval=100, callbacks=RichProgressBar(), @@ -71,12 +74,7 @@ if __name__ == "__main__": log_every_n_steps=1, ) - try: - trainer.tune(net) - trainer.fit(model=net) - except KeyboardInterrupt: - torch.save(net.state_dict(), "INTERRUPTED.pth") - raise + trainer.fit(model=net) # stop wandb wandb.run.finish() diff --git a/src/unet/__init__.py b/src/unet/__init__.py index ed74c60..9abb37e 100644 --- a/src/unet/__init__.py +++ b/src/unet/__init__.py @@ -1 +1 @@ -from .model import UNet +from .module import UNetModule diff --git a/src/unet/model.py b/src/unet/model.py index 4556640..bddfb14 100644 --- a/src/unet/model.py +++ b/src/unet/model.py @@ -1,37 +1,14 @@ -""" Full assembly of the parts to form the complete network """ +"""Full assembly of the parts to form the complete network.""" -import itertools - -import albumentations as A -import pytorch_lightning as pl -from albumentations.pytorch import ToTensorV2 -from torch.utils.data import DataLoader - -import wandb -from src.utils.dataset import SphereDataset -from utils.dice import dice_loss -from utils.paste import RandomPaste +import torch.nn as nn from .blocks import * -class_labels = { - 1: "sphere", -} - -class UNet(pl.LightningModule): - def __init__(self, n_channels, n_classes, learning_rate, batch_size, features=[64, 128, 256, 512]): +class UNet(nn.Module): + def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]): super(UNet, self).__init__() - # Hyperparameters - self.n_channels = n_channels - self.n_classes = n_classes - self.learning_rate = learning_rate - self.batch_size = batch_size - - # log hyperparameters - self.save_hyperparameters() - # Network self.inc = DoubleConv(n_channels, features[0]) @@ -66,224 +43,3 @@ class UNet(pl.LightningModule): x = self.outc(x) return x - - def train_dataloader(self): - tf_train = A.Compose( - [ - A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE), - A.Flip(), - A.ColorJitter(), - RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE), - A.GaussianBlur(), - A.ISONoise(), - A.ToFloat(max_value=255), - ToTensorV2(), - ], - ) - - ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train) - # ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000))) - - return DataLoader( - ds_train, - batch_size=self.batch_size, - shuffle=True, - num_workers=wandb.config.WORKERS, - pin_memory=wandb.config.PIN_MEMORY, - ) - - def val_dataloader(self): - ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG) - - return DataLoader( - ds_valid, - shuffle=False, - batch_size=1, - num_workers=wandb.config.WORKERS, - pin_memory=wandb.config.PIN_MEMORY, - ) - - def training_step(self, batch, batch_idx): - # unpacking - images, masks_true = batch - masks_true = masks_true.unsqueeze(1) - - # forward pass - masks_pred = self(images) - - # compute metrics - bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true) - dice = dice_loss(masks_pred, masks_true) - - masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() - dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False) - mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true) - accuracy = (masks_true == masks_pred_bin).float().mean() - - self.log_dict( - { - "train/accuracy": accuracy, - "train/dice": dice, - "train/dice_bin": dice_bin, - "train/bce": bce, - "train/mae": mae, - }, - ) - - if batch_idx == 22000: - rows = [] - columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"] - for i, (img, mask, pred, pred_bin) in enumerate( - zip( - images.cpu(), - masks_true.cpu(), - masks_pred.cpu(), - masks_pred_bin.cpu().squeeze(1).int().numpy(), - ) - ): - rows.append( - [ - i, - wandb.Image(img), - wandb.Image(mask), - wandb.Image( - pred, - masks={ - "predictions": { - "mask_data": pred_bin, - "class_labels": class_labels, - }, - }, - ), - dice, - dice_bin, - ] - ) - - # logging - try: # required by autofinding, logger replaced by dummy - self.logger.log_table( - key="train/predictions", - columns=columns, - data=rows, - ) - except: - pass - - return dict( - accuracy=accuracy, - loss=dice, - bce=bce, - mae=mae, - ) - - def validation_step(self, batch, batch_idx): - # unpacking - images, masks_true = batch - masks_true = masks_true.unsqueeze(1) - - # forward pass - masks_pred = self(images) - - # compute metrics - bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true) - dice = dice_loss(masks_pred, masks_true) - - masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() - dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False) - mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true) - accuracy = (masks_true == masks_pred_bin).float().mean() - - rows = [] - if batch_idx % 50 == 0 or dice > 0.9: - for i, (img, mask, pred, pred_bin) in enumerate( - zip( - images.cpu(), - masks_true.cpu(), - masks_pred.cpu(), - masks_pred_bin.cpu().squeeze(1).int().numpy(), - ) - ): - rows.append( - [ - i, - wandb.Image(img), - wandb.Image(mask), - wandb.Image( - pred, - masks={ - "predictions": { - "mask_data": pred_bin, - "class_labels": class_labels, - }, - }, - ), - dice, - dice_bin, - ] - ) - - return dict( - accuracy=accuracy, - loss=dice, - dice_bin=dice_bin, - bce=bce, - mae=mae, - table_rows=rows, - ) - - def validation_epoch_end(self, validation_outputs): - # matrics unpacking - accuracy = torch.stack([d["accuracy"] for d in validation_outputs]).mean() - dice_bin = torch.stack([d["dice_bin"] for d in validation_outputs]).mean() - loss = torch.stack([d["loss"] for d in validation_outputs]).mean() - bce = torch.stack([d["bce"] for d in validation_outputs]).mean() - mae = torch.stack([d["mae"] for d in validation_outputs]).mean() - - # table unpacking - columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"] - rowss = [d["table_rows"] for d in validation_outputs] - rows = list(itertools.chain.from_iterable(rowss)) - - # logging - try: # required by autofinding, logger replaced by dummy - self.logger.log_table( - key="val/predictions", - columns=columns, - data=rows, - ) - except: - pass - - self.log_dict( - { - "val/accuracy": accuracy, - "val/dice": loss, - "val/dice_bin": dice_bin, - "val/bce": bce, - "val/mae": mae, - } - ) - - # export model to pth - torch.save(self.state_dict(), f"checkpoints/model.pth") - artifact = wandb.Artifact("pth", type="model") - artifact.add_file("checkpoints/model.pth") - wandb.run.log_artifact(artifact) - - # export model to onnx - dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True) - torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx") - artifact = wandb.Artifact("onnx", type="model") - artifact.add_file("checkpoints/model.onnx") - wandb.run.log_artifact(artifact) - - def configure_optimizers(self): - optimizer = torch.optim.RMSprop( - self.parameters(), - lr=self.learning_rate, - weight_decay=wandb.config.WEIGHT_DECAY, - momentum=wandb.config.MOMENTUM, - ) - - return optimizer diff --git a/src/unet/module.py b/src/unet/module.py new file mode 100644 index 0000000..a08a383 --- /dev/null +++ b/src/unet/module.py @@ -0,0 +1,193 @@ +"""Pytorch lightning wrapper for model.""" + +import itertools + +import pytorch_lightning as pl + +import wandb +from unet.model import UNet +from utils.dice import dice_loss + +from .blocks import * + +class_labels = { + 1: "sphere", +} + + +class UNetModule(pl.LightningModule): + def __init__(self, n_channels, n_classes, learning_rate, batch_size, features=[64, 128, 256, 512]): + super(UNetModule, self).__init__() + + # Hyperparameters + self.n_channels = n_channels + self.n_classes = n_classes + self.learning_rate = learning_rate + self.batch_size = batch_size + + # log hyperparameters + self.save_hyperparameters() + + # Network + self.model = UNet(n_channels, n_classes, features) + + def forward(self, x): + return self.model(x) + + def shared_step(self, batch): + data, ground_truth = batch # unpacking + ground_truth = ground_truth.unsqueeze(1) # 1HW -> HW + + # forward pass, compute masks + prediction = self.model(data) + binary = (torch.sigmoid(prediction) > 0.5).float() # TODO: check if float necessary + + # compute metrics (in dictionnary) + metrics = { + "dice": dice_loss(prediction, ground_truth), + "dice_bin": dice_loss(binary, ground_truth, logits=False), + "bce": F.binary_cross_entropy_with_logits(prediction, ground_truth), + "mae": torch.nn.functional.l1_loss(binary, ground_truth), + "accuracy": (ground_truth == binary).float().mean(), + } + + # wrap tensors in dictionnary + tensors = { + "data": data, + "ground_truth": ground_truth, + "prediction": prediction, + "binary": binary, + } + + return metrics, tensors + + def training_step(self, batch, batch_idx): + # compute metrics + metrics, tensors = self.shared_step(batch) + + # log metrics + self.log_dict(dict([(f"train/{key}", value) for key, value in metrics.items()])) + + if batch_idx == 5000: + rows = [] + columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"] + for i, (img, mask, pred, pred_bin) in enumerate( + zip( # TODO: use comprehension list to zip the dictionnary + tensors["images"].cpu(), + tensors["ground_truth"].cpu(), + tensors["prediction"].cpu(), + tensors["binary"] + .cpu() + .squeeze(1) + .int() + .numpy(), # TODO: check if .functions can be moved elsewhere + ) + ): + rows.append( + [ + i, + wandb.Image(img), + wandb.Image(mask), + wandb.Image( + pred, + masks={ + "predictions": { + "mask_data": pred_bin, + "class_labels": class_labels, + }, + }, + ), + metrics["dice"], + metrics["dice_bin"], + ] + ) + + # log table + wandb.log( + { + "train/predictions": wandb.Table( + columns=columns, + data=rows, + ) + } + ) + + return metrics["dice"] + + def validation_step(self, batch, batch_idx): + metrics, tensors = self.shared_step(batch) + + rows = [] + if batch_idx % 50 == 0 or metrics["dice"] > 0.9: + for i, (img, mask, pred, pred_bin) in enumerate( + zip( # TODO: use comprehension list to zip the dictionnary + tensors["images"].cpu(), + tensors["ground_truth"].cpu(), + tensors["prediction"].cpu(), + tensors["binary"] + .cpu() + .squeeze(1) + .int() + .numpy(), # TODO: check if .functions can be moved elsewhere + ) + ): + rows.append( + [ + i, + wandb.Image(img), + wandb.Image(mask), + wandb.Image( + pred, + masks={ + "predictions": { + "mask_data": pred_bin, + "class_labels": class_labels, + }, + }, + ), + metrics["dice"], + metrics["dice_bin"], + ] + ) + + return metrics + + def validation_epoch_end(self, validation_outputs): + # unpacking + metricss, rowss = validation_outputs + + # metrics flattening + metrics = { + "dice": torch.stack([d["dice"] for d in metricss]).mean(), + "dice_bin": torch.stack([d["dice_bin"] for d in metricss]).mean(), + "bce": torch.stack([d["bce"] for d in metricss]).mean(), + "mae": torch.stack([d["mae"] for d in metricss]).mean(), + "accuracy": torch.stack([d["accuracy"] for d in validation_outputs]).mean(), + } + + # log metrics + self.log_dict(dict([(f"val/{key}", value) for key, value in metrics.items()])) + + # rows flattening + rows = list(itertools.chain.from_iterable(rowss)) + columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"] + + # log table + wandb.log( + { + "val/predictions": wandb.Table( + columns=columns, + data=rows, + ) + } + ) + + def configure_optimizers(self): + optimizer = torch.optim.RMSprop( + self.parameters(), + lr=self.learning_rate, + weight_decay=wandb.config.WEIGHT_DECAY, + momentum=wandb.config.MOMENTUM, + ) + + return optimizer diff --git a/src/utils/__init__.py b/src/utils/__init__.py index e69de29..546b2fe 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -0,0 +1 @@ +from .paste import RandomPaste