From 7bdac6583bf52dbc91e3b2cb772bea069adb1334 Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Thu, 30 Jun 2022 23:28:38 +0200 Subject: [PATCH] feat: binarized the masks + lots of new metrics to fix Former-commit-id: c840d14f722503d241f6bb6d899630ad6345aca0 [formerly e435a21234620add4f0e4e269a4141e5c1508cd9] Former-commit-id: 8006af185fd68cc88b2305a02513106c16758d77 --- comp.ipynb.REMOVED.git-id | 2 +- src/evaluate.py | 45 ------------------------- src/train.py | 70 ++++++++++++++++++++++++++++++--------- src/utils/dataset.py | 7 ++-- src/utils/paste.py | 7 +++- 5 files changed, 67 insertions(+), 64 deletions(-) delete mode 100644 src/evaluate.py diff --git a/comp.ipynb.REMOVED.git-id b/comp.ipynb.REMOVED.git-id index b1ef6f5..b439b71 100644 --- a/comp.ipynb.REMOVED.git-id +++ b/comp.ipynb.REMOVED.git-id @@ -1 +1 @@ -fb39f9a23b728fadb88ce579f78bb419ff0eaab6 \ No newline at end of file +9cbd3cff7e664a80a5a1fa1404898b7bba3cae0d \ No newline at end of file diff --git a/src/evaluate.py b/src/evaluate.py deleted file mode 100644 index cf39e98..0000000 --- a/src/evaluate.py +++ /dev/null @@ -1,45 +0,0 @@ -import numpy as np -import torch -from tqdm import tqdm - -import wandb -from src.utils.dice import dice_coeff - -class_labels = { - 1: "sphere", -} - - -def evaluate(net, dataloader, device): - net.eval() - num_val_batches = len(dataloader) - dice_score = 0 - - # iterate over the validation set - with tqdm(dataloader, total=len(dataloader.dataset), desc="val", unit="img", leave=False) as pbar: - for images, masks_true in dataloader: - # move images and labels to correct device - images = images.to(device=device) - masks_true = masks_true.unsqueeze(1).float().to(device=device) - - # forward, predict the mask - with torch.inference_mode(): - masks_pred = net(images) - masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() - - # compute the Dice score - dice_score += dice_coeff(masks_pred_bin, masks_true, reduce_batch_first=False) - - # update progress bar - pbar.update(images.shape[0]) - - # save some images to wandb - table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"]) - for i, (img, mask, pred) in enumerate(zip(images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu"))): - table.add_data(i, wandb.Image(img), wandb.Image(mask), wandb.Image(pred)) - wandb.log({"predictions_table": table}) - - net.train() - - # Fixes a potential division by zero error - return dice_score / num_val_batches if num_val_batches else dice_score diff --git a/src/train.py b/src/train.py index 5feb01c..8d4dab1 100644 --- a/src/train.py +++ b/src/train.py @@ -8,9 +8,9 @@ from torch.utils.data import DataLoader from tqdm import tqdm import wandb -from evaluate import evaluate from src.utils.dataset import SphereDataset from unet import UNet +from utils.dice import dice_coeff from utils.paste import RandomPaste @@ -22,7 +22,7 @@ def main(): wandb.init( project="U-Net", config=dict( - DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/val2017", + DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/smolval2017", DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/", DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/", DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/", @@ -51,7 +51,7 @@ def main(): # 0. Create network net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES) wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad) - wandb.watch(net, log_freq=100) + wandb.watch(net, log_freq=100) # TODO: 1/4 epochs # transfer network to device net.to(device=device) @@ -110,10 +110,6 @@ def main(): grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP) criterion = torch.nn.BCEWithLogitsLoss() - # accuracy stuff - mse = torch.nn.MSELoss() - mae = torch.nn.L1Loss() - # save model.pth torch.save(net.state_dict(), "checkpoints/model-0.pth") artifact = wandb.Artifact("pth", type="model") @@ -136,6 +132,9 @@ def main(): """ ) + # setup wandb table for saving images + table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"]) + try: for epoch in range(1, wandb.config.EPOCHS + 1): with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar: @@ -164,9 +163,9 @@ def main(): grad_scaler.update() # compute metrics - accuracy = (true_masks == pred_masks).float().mean() - mse = torch.nn.functional.mse_loss(pred_masks, true_masks) - mae = torch.nn.functional.l1_loss(pred_masks, true_masks) + pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float() + accuracy = (true_masks == pred_masks_bin).float().mean() + mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks) # update tqdm progress bar pbar.update(images.shape[0]) @@ -177,23 +176,64 @@ def main(): { "train/epoch": epoch - 1 + step / len(train_loader), "train/accuracy": accuracy, - "train/loss": train_loss, - "train/mse": mse, + "train/bce": train_loss, "train/mae": mae, } ) # Evaluation round - val_score = evaluate(net, val_loader, device) - scheduler.step(val_score) + net.eval() + accuracy = 0 + dice = 0 + mae = 0 + with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar: + for images, masks_true in val_loader: + + # transfer images to device + images = images.to(device=device) + masks_true = masks_true.unsqueeze(1).to(device=device) + + # forward + with torch.inference_mode(): + masks_pred = net(images) + + # compute metrics + masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() + accuracy += (true_masks == pred_masks_bin).float().sum() + dice += dice_coeff(masks_pred_bin, masks_true, reduce_batch_first=False) + mae += torch.nn.functional.l1_loss(pred_masks_bin, true_masks, reduction="sum") + + # update progress bar + pbar.update(images.shape[0]) + + accuracy /= len(ds_valid) + dice /= len(val_loader) # TODO: fix dice_coeff to not average + mae /= len(ds_valid) + + # save the last validation batch to table + for i, (img, mask, pred) in enumerate( + zip( + images.to("cpu"), + masks_true.to("cpu"), + masks_pred.to("cpu"), + ) + ): + table.add_data(i, wandb.Image(img), wandb.Image(mask), wandb.Image(pred)) # log validation metrics wandb.log( { - "val/val_score": val_score, + "val/predictions": table, + "val/accuracy": accuracy, + "val/dice": dice, + "val/mae": mae, } ) + # update hyperparameters + net.train() + scheduler.step(dice) + # save weights when epoch end torch.save(net.state_dict(), f"checkpoints/model-{epoch}.pth") artifact = wandb.Artifact("pth", type="model") diff --git a/src/utils/dataset.py b/src/utils/dataset.py index f1cb2be..a798731 100644 --- a/src/utils/dataset.py +++ b/src/utils/dataset.py @@ -1,4 +1,3 @@ -import logging import os import numpy as np @@ -24,6 +23,10 @@ class SphereDataset(Dataset): if self.transform is not None: augmentations = self.transform(image=image, mask=mask) image = augmentations["image"] - mask = augmentations["mask"].float() + mask = augmentations["mask"] + + # make sure image and mask are floats + image = image.float() + mask = mask.float() return image, mask diff --git a/src/utils/paste.py b/src/utils/paste.py index 8d5f904..486a8ec 100644 --- a/src/utils/paste.py +++ b/src/utils/paste.py @@ -42,6 +42,7 @@ class RandomPaste(A.DualTransform): # convert img to Image, needed for `paste` function img = Image.fromarray(img) + # paste spheres for pos in positions: img.paste(paste_img, pos, paste_mask) @@ -51,8 +52,12 @@ class RandomPaste(A.DualTransform): # convert mask to Image, needed for `paste` function mask = Image.fromarray(mask) + # binarize the mask -> {0, 1} + paste_mask_bin = paste_mask.point(lambda p: 1 if p > 10 else 0) + + # paste spheres for pos in positions: - mask.paste(paste_mask, pos, paste_mask) + mask.paste(paste_mask, pos, paste_mask_bin) return np.asarray(mask.convert("L"))