diff --git a/src/train.py b/src/train.py
index 9ddac8e..e8f3594 100644
--- a/src/train.py
+++ b/src/train.py
@@ -1,16 +1,16 @@
 import logging
 
 import albumentations as A
+import pytorch_lightning as pl
 import torch
 import yaml
 from albumentations.pytorch import ToTensorV2
+from pytorch_lightning.loggers import WandbLogger
 from torch.utils.data import DataLoader
-from tqdm import tqdm
 
 import wandb
 from src.utils.dataset import SphereDataset
 from unet import UNet
-from utils.dice import dice_coeff
 from utils.paste import RandomPaste
 
 class_labels = {
@@ -22,7 +22,7 @@ if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
 
     # setup wandb
-    wandb.init(
+    logger = WandbLogger(
         project="U-Net",
         config=dict(
             DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/train/",
@@ -36,7 +36,7 @@ if __name__ == "__main__":
             AMP=True,
             PIN_MEMORY=True,
             BENCHMARK=True,
-            DEVICE="cuda",
+            DEVICE="gpu",
             WORKERS=8,
             EPOCHS=5,
             BATCH_SIZE=16,
@@ -51,18 +51,17 @@ if __name__ == "__main__":
         ),
     )
 
-    # create device
-    device = torch.device(wandb.config.DEVICE)
-
-    # enable cudnn benchmarking
-    torch.backends.cudnn.benchmark = wandb.config.BENCHMARK
+    # seed random generators
+    pl.seed_everything(69420, workers=True)
+
+    # WandbLogger creates its wandb run lazily; touching `experiment` forces
+    # wandb.init() so that wandb.config is populated before it is read below
+    _ = logger.experiment
 
     # 0. Create network
     net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
+
+    # log the number of parameters of the model
     wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad)
 
-    # transfer network to device
-    net.to(device=device)
+    # log gradients and weights regularly
+    logger.watch(net, log="all")
 
     # 1. Create transforms
     tf_train = A.Compose(
@@ -121,244 +120,38 @@ if __name__ == "__main__":
         pin_memory=wandb.config.PIN_MEMORY,
     )
 
-    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp
-    optimizer = torch.optim.RMSprop(
-        net.parameters(),
-        lr=wandb.config.LEARNING_RATE,
-        weight_decay=wandb.config.WEIGHT_DECAY,
-        momentum=wandb.config.MOMENTUM,
+    # 4. Create the trainer
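+    # the Trainer subsumes the hand-rolled loop deleted below: device
+    # placement, AMP loss scaling, and the train/val/test loops are all
+    # handled by Lightning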
+    trainer = pl.Trainer(
+        max_epochs=wandb.config.EPOCHS,
+        accelerator=wandb.config.DEVICE,
+        precision=16,
+        auto_scale_batch_size="binsearch",
+        benchmark=wandb.config.BENCHMARK,
+        val_check_interval=100,
+        logger=logger,
     )
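+
+    # NOTE: auto_scale_batch_size only takes effect when trainer.tune() is
+    # called before fitting; a minimal sketch (it assumes `net` exposes a
+    # `batch_size` attribute, which this UNet does not define yet):
+    #
+    #     trainer.tune(net, train_dataloaders=train_loader)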
-    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
-    grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
-    criterion = torch.nn.BCEWithLogitsLoss()
-
-    # save model.onxx
-    dummy_input = torch.randn(
-        1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True
-    ).to(device)
-    torch.onnx.export(net, dummy_input, "checkpoints/model-0.onnx")
-    artifact = wandb.Artifact("onnx", type="model")
-    artifact.add_file("checkpoints/model-0.onnx")
-    wandb.run.log_artifact(artifact)
-
-    # log gradients and weights four time per epoch
-    wandb.watch(net, criterion, log_freq=100)
 
     # print the config
     logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
 
-    # wandb init log
-    wandb.log(
-        {
-            "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
-        },
-        commit=False,
-    )
-
     try:
-        for epoch in range(1, wandb.config.EPOCHS + 1):
-            with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar:
-
-                # Training round
-                for step, (images, true_masks) in enumerate(train_loader):
-                    assert images.shape[1] == net.n_channels, (
-                        f"Network has been defined with {net.n_channels} input channels, "
-                        f"but loaded images have {images.shape[1]} channels. Please check that "
-                        "the images are loaded correctly."
-                    )
-
-                    # transfer images to device
-                    images = images.to(device=device)
-                    true_masks = true_masks.unsqueeze(1).to(device=device)
-
-                    # forward
-                    with torch.cuda.amp.autocast(enabled=wandb.config.AMP):
-                        pred_masks = net(images)
-                        train_loss = criterion(pred_masks, true_masks)
-
-                    # backward
-                    optimizer.zero_grad(set_to_none=True)
-                    grad_scaler.scale(train_loss).backward()
-                    grad_scaler.step(optimizer)
-                    grad_scaler.update()
-
-                    # compute metrics
-                    pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float()
-                    accuracy = (true_masks == pred_masks_bin).float().mean()
-                    dice = dice_coeff(pred_masks_bin, true_masks)
-                    mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
-
-                    # update tqdm progress bar
-                    pbar.update(images.shape[0])
-                    pbar.set_postfix(**{"loss": train_loss.item()})
-
-                    # log metrics
-                    wandb.log(
-                        {
-                            "epoch": epoch - 1 + step / len(train_loader),
-                            "train/accuracy": accuracy,
-                            "train/bce": train_loss,
-                            "train/dice": dice,
-                            "train/mae": mae,
-                        }
-                    )
-
-                    if step and (step % 250 == 0 or step == len(train_loader)):
-                        # Evaluation round
-                        net.eval()
-                        accuracy = 0
-                        val_loss = 0
-                        dice = 0
-                        mae = 0
-                        with tqdm(val_loader, total=len(ds_valid), desc="val.", unit="img", leave=False) as pbar2:
-                            for images, masks_true in val_loader:
-
-                                # transfer images to device
-                                images = images.to(device=device)
-                                masks_true = masks_true.unsqueeze(1).to(device=device)
-
-                                # forward
-                                with torch.inference_mode():
-                                    masks_pred = net(images)
-
-                                # compute metrics
-                                val_loss += criterion(masks_pred, masks_true)
-                                masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
-                                mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
-                                accuracy += (masks_true == masks_pred_bin).float().mean()
-                                dice += dice_coeff(masks_pred_bin, masks_true)
-
-                                # update progress bar
-                                pbar2.update(images.shape[0])
-
-                        accuracy /= len(val_loader)
-                        val_loss /= len(val_loader)
-                        dice /= len(val_loader)
-                        mae /= len(val_loader)
-
-                        # save the last validation batch to table
-                        table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
-                        for i, (img, mask, pred, pred_bin) in enumerate(
-                            zip(
-                                images.cpu(),
-                                masks_true.cpu(),
-                                masks_pred.cpu(),
-                                masks_pred_bin.cpu().squeeze(1).int().numpy(),
-                            )
-                        ):
-                            table.add_data(
-                                i,
-                                wandb.Image(img),
-                                wandb.Image(mask),
-                                wandb.Image(
-                                    pred,
-                                    masks={
-                                        "predictions": {
-                                            "mask_data": pred_bin,
-                                            "class_labels": class_labels,
-                                        },
-                                    },
-                                ),
-                            )
-
-                        # log validation metrics
-                        wandb.log(
-                            {
-                                "val/predictions": table,
-                                "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
-                                "val/accuracy": accuracy,
-                                "val/bce": val_loss,
-                                "val/dice": dice,
-                                "val/mae": mae,
-                            },
-                            commit=False,
-                        )
-
-                        # update hyperparameters
-                        net.train()
-                        scheduler.step(dice)
-
-                        # export model to onnx format when validation ends
-                        dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
-                        torch.onnx.export(net, dummy_input, f"checkpoints/model-{epoch}-{step}.onnx")
-                        artifact = wandb.Artifact("onnx", type="model")
-                        artifact.add_file(f"checkpoints/model-{epoch}-{step}.onnx")
-                        wandb.run.log_artifact(artifact)
-
-        # testing round
-        net.eval()
-        accuracy = 0
-        val_loss = 0
-        dice = 0
-        mae = 0
-        with tqdm(test_loader, total=len(ds_test), desc="test", unit="img", leave=False) as pbar3:
-            for images, masks_true in test_loader:
-
-                # transfer images to device
-                images = images.to(device=device)
-                masks_true = masks_true.unsqueeze(1).to(device=device)
-
-                # forward
-                with torch.inference_mode():
-                    masks_pred = net(images)
-
-                # compute metrics
-                val_loss += criterion(masks_pred, masks_true)
-                masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
-                mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
-                accuracy += (masks_true == masks_pred_bin).float().mean()
-                dice += dice_coeff(masks_pred_bin, masks_true)
-
-                # update progress bar
-                pbar3.update(images.shape[0])
-
-        accuracy /= len(test_loader)
-        val_loss /= len(test_loader)
-        dice /= len(test_loader)
-        mae /= len(test_loader)
-
-        # save the last validation batch to table
-        table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
-        for i, (img, mask, pred, pred_bin) in enumerate(
-            zip(
-                images.cpu(),
-                masks_true.cpu(),
-                masks_pred.cpu(),
-                masks_pred_bin.cpu().squeeze(1).int().numpy(),
-            )
-        ):
-            table.add_data(
-                i,
-                wandb.Image(img),
-                wandb.Image(mask),
-                wandb.Image(
-                    pred,
-                    masks={
-                        "predictions": {
-                            "mask_data": pred_bin,
-                            "class_labels": class_labels,
-                        },
-                    },
-                ),
-            )
-
-        # log validation metrics
-        wandb.log(
-            {
-                "test/predictions": table,
-                "test/accuracy": accuracy,
-                "test/bce": val_loss,
-                "test/dice": dice,
-                "test/mae": mae,
-            },
-            commit=False,
-        )
-
-        # stop wandb
-        wandb.run.finish()
-
+        trainer.fit(
+            model=net,
+            train_dataloaders=train_loader,
+            val_dataloaders=val_loader,
+        )
+
+        # fit() does not take test dataloaders; run the test loop explicitly
+        # once training is complete
+        trainer.test(
+            model=net,
+            dataloaders=test_loader,
+        )
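+
+        # NOTE: this evaluates the weights left in memory after fit(); to
+        # test a saved checkpoint instead, a ckpt_path could be passed
+        # (assuming a ModelCheckpoint callback is configured), e.g.:
+        #
+        #     trainer.test(model=net, dataloaders=test_loader, ckpt_path="best")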
     except KeyboardInterrupt:
         torch.save(net.state_dict(), "INTERRUPTED.pth")
         raise
-
-# sapin de noel
+
+    # stop wandb
+    wandb.run.finish()
diff --git a/src/unet/model.py b/src/unet/model.py
index 08d2807..378b407 100644
--- a/src/unet/model.py
+++ b/src/unet/model.py
@@ -1,9 +1,21 @@
 """ Full assembly of the parts to form the complete network """
 
+import pytorch_lightning as pl
+import torch
+import torch.nn.functional as F
+
+import wandb
+from utils.dice import dice_coeff
+
 from .blocks import *
 
+class_labels = {
+    1: "sphere",
+}
+
 
-class UNet(nn.Module):
+class UNet(pl.LightningModule):
     def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
         super(UNet, self).__init__()
         self.n_channels = n_channels
@@ -26,7 +38,6 @@ class UNet(nn.Module):
         self.outc = OutConv(features[0], n_classes)
 
     def forward(self, x):
-
         skips = []
 
         x = self.inc(x)
@@ -41,3 +52,158 @@ class UNet(nn.Module):
         x = self.outc(x)
 
         return x
+
+    @staticmethod
+    def save_to_table(images, masks_true, masks_pred, masks_pred_bin, log_key):
+        table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
+
+        for i, (img, mask, pred, pred_bin) in enumerate(
+            zip(
+                images.cpu(),
+                masks_true.cpu(),
+                masks_pred.cpu(),
+                masks_pred_bin.cpu().squeeze(1).int().numpy(),
+            )
+        ):
+            table.add_data(
+                i,
+                wandb.Image(img),
+                wandb.Image(mask),
+                wandb.Image(
+                    pred,
+                    masks={
+                        "predictions": {
+                            "mask_data": pred_bin,
+                            "class_labels": class_labels,
+                        },
+                    },
+                ),
+            )
+
+        wandb.log(
+            {
+                log_key: table,
+            }
+        )
+
+    def training_step(self, batch, batch_idx):
+        # unpacking
+        images, masks_true = batch
+        masks_true = masks_true.unsqueeze(1)
+        masks_pred = self(images)
+        masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
+
+        # compute metrics (binary cross-entropy, matching the
+        # BCEWithLogitsLoss used by the previous loop)
+        loss = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
+        mae = F.l1_loss(masks_pred_bin, masks_true)
+        accuracy = (masks_true == masks_pred_bin).float().mean()
+        dice = dice_coeff(masks_pred_bin, masks_true)
+
+        wandb.log(
+            {
+                "train/accuracy": accuracy,
+                "train/bce": loss,
+                "train/dice": dice,
+                "train/mae": mae,
+            }
+        )
+
+        # Lightning expects the loss (or a dict containing it) here
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        # unpacking
+        images, masks_true = batch
+        masks_true = masks_true.unsqueeze(1)
+        masks_pred = self(images)
+        masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
+
+        # compute metrics
+        loss = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
+        mae = F.l1_loss(masks_pred_bin, masks_true)
+        accuracy = (masks_true == masks_pred_bin).float().mean()
+        dice = dice_coeff(masks_pred_bin, masks_true)
+
+        # make the dice score visible to the ReduceLROnPlateau monitor
+        self.log("val/dice", dice)
+
+        if batch_idx == 0:
+            self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "val/predictions")
+
+        return loss, dice, accuracy, mae
+
+    def validation_epoch_end(self, validation_outputs):
+        # average the per-batch metrics over the whole validation set
+        loss = torch.stack([out[0] for out in validation_outputs]).mean()
+        dice = torch.stack([out[1] for out in validation_outputs]).mean()
+        accuracy = torch.stack([out[2] for out in validation_outputs]).mean()
+        mae = torch.stack([out[3] for out in validation_outputs]).mean()
+
+        # get the current learning rate
+        optimizer = self.trainer.optimizers[0]
+        learning_rate = optimizer.param_groups[0]["lr"]
+
+        wandb.log(
+            {
+                "train/learning_rate": learning_rate,
+                "val/accuracy": accuracy,
+                "val/bce": loss,
+                "val/dice": dice,
+                "val/mae": mae,
+            }
+        )
+
+        # export model to onnx once per validation round
+        dummy_input = torch.randn(1, 3, 512, 512, device=self.device)
+        torch.onnx.export(self, dummy_input, "checkpoints/model.onnx")
+        artifact = wandb.Artifact("onnx", type="model")
+        artifact.add_file("checkpoints/model.onnx")
+        wandb.run.log_artifact(artifact)
+
+    def test_step(self, batch, batch_idx):
+        # unpacking
+        images, masks_true = batch
+        masks_true = masks_true.unsqueeze(1)
+        masks_pred = self(images)
+        masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
+
+        # compute metrics
+        loss = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
+        mae = F.l1_loss(masks_pred_bin, masks_true)
+        accuracy = (masks_true == masks_pred_bin).float().mean()
+        dice = dice_coeff(masks_pred_bin, masks_true)
+
+        if batch_idx == 0:
+            self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "test/predictions")
+
+        return loss, dice, accuracy, mae
+
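+    # test_step returns one (loss, dice, accuracy, mae) tuple per batch;
+    # Lightning collects them into a list and hands it to test_epoch_end
+    # once the whole test set has been consumed, so the averaging below
+    # runs exactly once per test run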
+    def test_epoch_end(self, test_outputs):
+        # unpacking
+        list_loss = [out[0] for out in test_outputs]
+        list_dice = [out[1] for out in test_outputs]
+        list_accuracy = [out[2] for out in test_outputs]
+        list_mae = [out[3] for out in test_outputs]
+
+        # averaging (torch.stack keeps everything on-device; np.mean would
+        # choke on CUDA tensors)
+        loss = torch.stack(list_loss).mean()
+        dice = torch.stack(list_dice).mean()
+        accuracy = torch.stack(list_accuracy).mean()
+        mae = torch.stack(list_mae).mean()
+
+        wandb.log(
+            {
+                "test/accuracy": accuracy,
+                "test/bce": loss,
+                "test/dice": dice,
+                "test/mae": mae,
+            }
+        )
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.RMSprop(
+            self.parameters(),
+            lr=wandb.config.LEARNING_RATE,
+            weight_decay=wandb.config.WEIGHT_DECAY,
+            momentum=wandb.config.MOMENTUM,
+        )
+        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+            optimizer,
+            "max",
+            patience=2,
+        )
+
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": scheduler,
+                "monitor": "val/dice",
+            },
+        }
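+
+    # NOTE: the dict form above is needed because a bare
+    # `return optimizer, scheduler` is not a return shape Lightning accepts,
+    # and a ReduceLROnPlateau scheduler must name the metric it monitors;
+    # "val/dice" is the value recorded via self.log in validation_step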