diff --git a/src/data/dataloader.py b/src/data/dataloader.py index c08f9c4..140a394 100644 --- a/src/data/dataloader.py +++ b/src/data/dataloader.py @@ -19,8 +19,13 @@ class Spheres(pl.LightningDataModule): def train_dataloader(self): transforms = A.Compose( [ - A.ToFloat(max_value=255), - ToTensorV2(), + A.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + max_pixel_value=255, + ), # [0, 255] -> [0.0, 1.0] normalized + # A.ToFloat(max_value=255), + ToTensorV2(), # HWC -> CHW ], bbox_params=A.BboxParams( format="pascal_voc", @@ -30,17 +35,18 @@ class Spheres(pl.LightningDataModule): ), ) - dataset = RealDataset(root="/media/disk1/lfainsin/TEST_tmp_mrcnn/", transforms=transforms) - print(f"len(dataset)={len(dataset)}") + dataset = RealDataset(root="/dev/shm/TEST_tmp_mrcnn/", transforms=transforms) dataset = Subset(dataset, list(range(len(dataset)))) # somehow this allows to better utilize the gpu + # dataset = Subset(dataset, list(range(20))) # somehow this allows to better utilize the gpu return DataLoader( dataset, - shuffle=True, + shuffle=False, + persistent_workers=True, prefetch_factor=wandb.config.PREFETCH_FACTOR, batch_size=wandb.config.TRAIN_BATCH_SIZE, - num_workers=wandb.config.WORKERS, pin_memory=wandb.config.PIN_MEMORY, + num_workers=wandb.config.WORKERS, collate_fn=collate_fn, ) diff --git a/src/train.py b/src/train.py index f7027bb..f0575db 100644 --- a/src/train.py +++ b/src/train.py @@ -7,7 +7,6 @@ from pytorch_lightning.loggers import WandbLogger import wandb from data import Spheres from mrcnn import MRCNNModule -from unet import UNetModule from utils import ArtifactLog, TableLog if __name__ == "__main__": @@ -16,7 +15,7 @@ if __name__ == "__main__": # setup wandb logger = WandbLogger( - project="U-Net", + project="Mask R-CNN", config="wandb.yaml", settings=wandb.Settings( code_dir="./src/", @@ -24,15 +23,9 @@ if __name__ == "__main__": ) # seed random generators - pl.seed_everything(69420, workers=True) - - # Create network - # model = UNetModule( - # n_channels=wandb.config.N_CHANNELS, - # n_classes=wandb.config.N_CLASSES, - # features=wandb.config.FEATURES, - # ) + pl.seed_everything(wandb.config.SEED, workers=True) + # Create Network model = MRCNNModule( hidden_layer_size=-1, n_classes=2, @@ -54,13 +47,13 @@ if __name__ == "__main__": max_epochs=wandb.config.EPOCHS, accelerator=wandb.config.DEVICE, benchmark=wandb.config.BENCHMARK, - deterministic=True, - precision=16, + deterministic=wandb.config.DETERMINISTIC, + precision=wandb.config.PRECISION, logger=logger, - log_every_n_steps=1, + log_every_n_steps=5, # val_check_interval=100, callbacks=[RichProgressBar(), ArtifactLog(), TableLog()], - # profiler="simple", + # profiler="advanced", num_sanity_val_steps=0, ) diff --git a/src/unet/__init__.py b/src/unet/__init__.py deleted file mode 100644 index 9abb37e..0000000 --- a/src/unet/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .module import UNetModule diff --git a/src/unet/blocks.py b/src/unet/blocks.py deleted file mode 100644 index 0df7f5f..0000000 --- a/src/unet/blocks.py +++ /dev/null @@ -1,82 +0,0 @@ -"""Parts of the U-Net model.""" - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class DoubleConv(nn.Module): - """(convolution => [BN] => ReLU) * 2""" - - def __init__(self, in_channels, out_channels, mid_channels=None): - super().__init__() - - if not mid_channels: - mid_channels = out_channels - - self.double_conv = nn.Sequential( - # first convolution - nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(mid_channels), - nn.ReLU(inplace=True), - # second convolution - nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True), - ) - - def forward(self, x): - return self.double_conv(x) - - -class Down(nn.Module): - """Downscaling with maxpool then double conv""" - - def __init__(self, in_channels, out_channels): - super().__init__() - - self.maxpool_conv = nn.Sequential( - nn.MaxPool2d(2), - DoubleConv(in_channels, out_channels), - ) - - def forward(self, x): - return self.maxpool_conv(x) - - -class Up(nn.Module): - """Upscaling then double conv""" - - def __init__(self, in_channels, out_channels): - super().__init__() - - self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) - self.conv = DoubleConv(in_channels, out_channels) - - def forward(self, x1, x2): - x1 = self.up(x1) - - # input is CHW - diffY = x2.size()[2] - x1.size()[2] - diffX = x2.size()[3] - x1.size()[3] - diffY2 = torch.div(diffY, 2, rounding_mode="trunc") - diffX2 = torch.div(diffX, 2, rounding_mode="trunc") - - x1 = F.pad( - input=x1, - pad=[diffX2, diffX - diffX2, diffY2, diffY - diffY2], - ) - - x = torch.cat([x2, x1], dim=1) - - return self.conv(x) - - -class OutConv(nn.Module): - def __init__(self, in_channels, out_channels): - super(OutConv, self).__init__() - - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) - - def forward(self, x): - return self.conv(x) diff --git a/src/unet/model.py b/src/unet/model.py deleted file mode 100644 index a5b8e2a..0000000 --- a/src/unet/model.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Full assembly of the parts to form the complete network.""" - -import torch.nn as nn - -from .blocks import * - - -class UNet(nn.Module): - def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]): - super(UNet, self).__init__() - - # Network - self.inc = DoubleConv(n_channels, features[0]) - - self.downs = nn.ModuleList() - for i in range(len(features) - 1): - self.downs.append( - Down(*features[i : i + 2]), - ) - - self.ups = nn.ModuleList() - for i in range(len(features) - 1): - self.ups.append( - Up(*features[-1 - i : -3 - i : -1]), - ) - - self.outc = OutConv(features[0], n_classes) - - def forward(self, x): - skips = [] - - x = self.inc(x) - - for down in self.downs: - skips.append(x) - x = down(x) - - for up, skip in zip(self.ups, reversed(skips)): - x = up(x, skip) - - x = self.outc(x) - - return x diff --git a/src/unet/module.py b/src/unet/module.py deleted file mode 100644 index 43b09d0..0000000 --- a/src/unet/module.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Pytorch lightning wrapper for model.""" - -import itertools - -import pytorch_lightning as pl - -import wandb -from unet.model import UNet -from utils.dice import dice_loss - -from .blocks import * - -class_labels = { - 1: "sphere", -} - - -class UNetModule(pl.LightningModule): - def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]): - super(UNetModule, self).__init__() - - # Hyperparameters - self.n_channels = n_channels - self.n_classes = n_classes - - # log hyperparameters - self.save_hyperparameters() - - # Network - self.model = UNet(n_channels, n_classes, features) - - def forward(self, x): - return self.model(x) - - def shared_step(self, batch): - data, ground_truth = batch # unpacking - ground_truth = ground_truth.unsqueeze(1) # 1HW -> HW - - # forward pass, compute masks - prediction = self.model(data) - binary = (torch.sigmoid(prediction) > 0.5).half() - - # compute metrics (in dictionnary) - metrics = { - "dice": dice_loss(prediction, ground_truth), - "dice_bin": dice_loss(binary, ground_truth, logits=False), - "bce": F.binary_cross_entropy_with_logits(prediction, ground_truth), - "mae": torch.nn.functional.l1_loss(binary, ground_truth), - "accuracy": (ground_truth == binary).float().mean(), - } - - # wrap tensors in dictionnary - predictions = { - "linear": prediction, - "binary": binary, - } - - return metrics, predictions - - def training_step(self, batch, batch_idx): - # compute metrics - metrics, _ = self.shared_step(batch) - - # log metrics - self.log_dict(dict([(f"train/{key}", value) for key, value in metrics.items()])) - - return metrics["dice"] - - def validation_step(self, batch, batch_idx): - # compute metrics - metrics, predictions = self.shared_step(batch) - - # log metrics - self.log_dict(dict([(f"val/{key}", value) for key, value in metrics.items()])) - - return metrics, predictions - - def validation_epoch_end(self, validation_outputs): - # unpacking - metricss = [v[0] for v in validation_outputs] - rowss = [v[1] for v in validation_outputs] - - # metrics flattening - metrics = { - "dice": torch.stack([d["dice"] for d in metricss]).mean(), - "dice_bin": torch.stack([d["dice_bin"] for d in metricss]).mean(), - "bce": torch.stack([d["bce"] for d in metricss]).mean(), - "mae": torch.stack([d["mae"] for d in metricss]).mean(), - "accuracy": torch.stack([d["accuracy"] for d in metricss]).mean(), - } - - # log metrics - self.log_dict(dict([(f"val/{key}", value) for key, value in metrics.items()])) - - # rows flattening - rows = list(itertools.chain.from_iterable(rowss)) - columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"] - - # log table - wandb.log( - { - "val/predictions": wandb.Table( - columns=columns, - data=rows, - ) - } - ) - - def configure_optimizers(self): - optimizer = torch.optim.RMSprop( - self.parameters(), - lr=wandb.config.LEARNING_RATE, - weight_decay=wandb.config.WEIGHT_DECAY, - momentum=wandb.config.MOMENTUM, - ) - - return optimizer diff --git a/wandb.yaml b/wandb.yaml index 738b412..e7da84d 100644 --- a/wandb.yaml +++ b/wandb.yaml @@ -17,11 +17,17 @@ AMP: PIN_MEMORY: value: True BENCHMARK: + value: True +DETERMINISTIC: value: False +PRECISION: + value: 16 +SEED: + value: 69420 DEVICE: value: gpu WORKERS: - value: 1 + value: 16 IMG_SIZE: value: 512 @@ -29,17 +35,17 @@ SPHERES: value: 3 EPOCHS: - value: 3 + value: 10 TRAIN_BATCH_SIZE: - value: 2 # 100 + value: 8 VAL_BATCH_SIZE: - value: 0 # 10 + value: 0 PREFETCH_FACTOR: - value: 1 + value: 2 LEARNING_RATE: - value: 1.0e-4 + value: 0.005 WEIGHT_DECAY: - value: 1.0e-8 + value: 0.0005 MOMENTUM: value: 0.9