From b71b57285f106c58f5b0e384bca7908a2b9bc7d0 Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Wed, 6 Jul 2022 14:27:26 +0200 Subject: [PATCH] feat: using dice_loss feat: paste aug contrast/sharpness Former-commit-id: 93f19e9643858a81ace14e9a697dfb6b3cca4d47 [formerly f6ef5f65e84f37b4b55a99a49442b7d30d6d3911] Former-commit-id: 2f49a81340a91ab7456d093a849ed294457f8a83 --- src/train.py | 2 +- src/unet/model.py | 81 ++++++++++++-------------------------------- src/utils/dice.py | 84 ++++++---------------------------------------- src/utils/paste.py | 16 ++++++--- 4 files changed, 46 insertions(+), 137 deletions(-) diff --git a/src/train.py b/src/train.py index 656eea4..d97802d 100644 --- a/src/train.py +++ b/src/train.py @@ -23,7 +23,7 @@ CONFIG = { "BENCHMARK": True, "DEVICE": "gpu", "WORKERS": 8, - "EPOCHS": 5, + "EPOCHS": 10, "BATCH_SIZE": 16, "LEARNING_RATE": 1e-4, "WEIGHT_DECAY": 1e-8, diff --git a/src/unet/model.py b/src/unet/model.py index ccc036c..4e87827 100644 --- a/src/unet/model.py +++ b/src/unet/model.py @@ -9,7 +9,7 @@ from torch.utils.data import DataLoader import wandb from src.utils.dataset import SphereDataset -from utils.dice import dice_coeff +from utils.dice import dice_loss from utils.paste import RandomPaste from .blocks import * @@ -111,28 +111,29 @@ class UNet(pl.LightningModule): # forward pass masks_pred = self(images) - # compute loss + # compute metrics bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true) + dice = dice_loss(masks_pred, masks_true) - # compute other metrics masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() + dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False) mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true) accuracy = (masks_true == masks_pred_bin).float().mean() - dice = dice_coeff(masks_pred_bin, masks_true) self.log_dict( { "train/accuracy": accuracy, - "train/bce": bce, "train/dice": dice, + "train/dice_bin": dice_bin, + "train/bce": bce, "train/mae": mae, }, ) return dict( - loss=bce, - dice=dice, accuracy=accuracy, + loss=dice, + bce=bce, mae=mae, ) @@ -144,17 +145,17 @@ class UNet(pl.LightningModule): # forward pass masks_pred = self(images) - # compute loss + # compute metrics bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true) + dice = dice_loss(masks_pred, masks_true) - # compute other metrics masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() + dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False) mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true) accuracy = (masks_true == masks_pred_bin).float().mean() - dice = dice_coeff(masks_pred_bin, masks_true) rows = [] - if batch_idx < 6: + if batch_idx % 50 == 0: for i, (img, mask, pred, pred_bin) in enumerate( zip( images.cpu(), @@ -181,9 +182,10 @@ class UNet(pl.LightningModule): ) return dict( - loss=bce, - dice=dice, accuracy=accuracy, + loss=dice, + dice_bin=dice_bin, + bce=bce, mae=mae, table_rows=rows, ) @@ -191,8 +193,9 @@ class UNet(pl.LightningModule): def validation_epoch_end(self, validation_outputs): # matrics unpacking accuracy = torch.stack([d["accuracy"] for d in validation_outputs]).mean() + dice_bin = torch.stack([d["dice_bin"] for d in validation_outputs]).mean() loss = torch.stack([d["loss"] for d in validation_outputs]).mean() - dice = torch.stack([d["dice"] for d in validation_outputs]).mean() + bce = torch.stack([d["bce"] for d in validation_outputs]).mean() mae = torch.stack([d["mae"] for d in validation_outputs]).mean() # table unpacking @@ -201,7 +204,7 @@ class UNet(pl.LightningModule): rows = list(itertools.chain.from_iterable(rowss)) # logging - try: + try: # required by autofinding, logger replaced by dummy self.logger.log_table( key="val/predictions", columns=columns, @@ -209,11 +212,13 @@ class UNet(pl.LightningModule): ) except: pass + self.log_dict( { "val/accuracy": accuracy, - "val/bce": loss, - "val/dice": dice, + "val/dice": loss, + "val/dice_bin": dice_bin, + "val/bce": bce, "val/mae": mae, } ) @@ -231,48 +236,6 @@ class UNet(pl.LightningModule): artifact.add_file(f"checkpoints/model.onnx") wandb.run.log_artifact(artifact) - # def test_step(self, batch, batch_idx): - # # unpacking - # images, masks_true = batch - # masks_true = masks_true.unsqueeze(1) - # masks_pred = self(images) - # masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() - - # # compute metrics - # loss = F.cross_entropy(masks_pred, masks_true) - # mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true) - # accuracy = (masks_true == masks_pred_bin).float().mean() - # dice = dice_coeff(masks_pred_bin, masks_true) - - # if batch_idx == 0: - # self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "test/predictions") - - # return loss, dice, accuracy, mae - - # def test_step_end(self, test_outputs): - # # unpacking - # list_loss, list_dice, list_accuracy, list_mae = test_outputs - - # # averaging - # loss = np.mean(list_loss) - # dice = np.mean(list_dice) - # accuracy = np.mean(list_accuracy) - # mae = np.mean(list_mae) - - # # # get learning rate - # # optimizer = self.optimizers[0] - # # learning_rate = optimizer.state_dict()["param_groups"][0]["lr"] - - # wandb.log( - # { - # # "train/learning_rate": learning_rate, - # "test/accuracy": accuracy, - # "test/bce": loss, - # "test/dice": dice, - # "test/mae": mae, - # } - # ) - def configure_optimizers(self): optimizer = torch.optim.RMSprop( self.parameters(), diff --git a/src/utils/dice.py b/src/utils/dice.py index a29f794..acdc4b2 100644 --- a/src/utils/dice.py +++ b/src/utils/dice.py @@ -1,80 +1,18 @@ import torch -from torch import Tensor -def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float: - """Average of Dice coefficient for all batches, or for a single mask. +def dice_score(inputs, targets, smooth=1, logits=True): + # comment out if your model contains a sigmoid or equivalent activation layer + if logits: + inputs = torch.sigmoid(inputs) - Args: - input (Tensor): _description_ - target (Tensor): _description_ - reduce_batch_first (bool, optional): _description_. Defaults to False. - epsilon (_type_, optional): _description_. Defaults to 1e-6. + # flatten label and prediction tensors + inputs = inputs.view(-1) + targets = targets.view(-1) - Raises: - ValueError: _description_ - - Returns: - float: _description_ - """ - assert input.size() == target.size() - - if input.dim() == 2 and reduce_batch_first: - raise ValueError(f"Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})") - - if input.dim() == 2 or reduce_batch_first: - inter = torch.dot(input.reshape(-1), target.reshape(-1)) - sets_sum = torch.sum(input) + torch.sum(target) - - if sets_sum.item() == 0: - sets_sum = 2 * inter - - return (2 * inter + epsilon) / (sets_sum + epsilon) - else: - # compute and average metric for each batch element - dice = 0 - - for i in range(input.shape[0]): - dice += dice_coeff(input[i, ...], target[i, ...]) - - return dice / input.shape[0] + intersection = (inputs * targets).sum() + return (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth) -def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float: - """Average of Dice coefficient for all classes. - - Args: - input (Tensor): _description_ - target (Tensor): _description_ - reduce_batch_first (bool, optional): _description_. Defaults to False. - epsilon (_type_, optional): _description_. Defaults to 1e-6. - - Returns: - float: _description_ - """ - assert input.size() == target.size() - - dice = 0 - - for channel in range(input.shape[1]): - dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon) - - return dice / input.shape[1] - - -def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False) -> float: - """Dice loss (objective to minimize) between 0 and 1. - - Args: - input (Tensor): _description_ - target (Tensor): _description_ - multiclass (bool, optional): _description_. Defaults to False. - - Returns: - float: _description_ - """ - assert input.size() == target.size() - - fn = multiclass_dice_coeff if multiclass else dice_coeff - - return 1 - fn(input, target, reduce_batch_first=True) +def dice_loss(inputs, targets, smooth=1, logits=True): + return 1 - dice_score(inputs, targets, smooth, logits) diff --git a/src/utils/paste.py b/src/utils/paste.py index 486a8ec..a1e24e4 100644 --- a/src/utils/paste.py +++ b/src/utils/paste.py @@ -93,6 +93,18 @@ class RandomPaste(A.DualTransform): target_shape = np.array(target_img.shape[:2], dtype=np.uint) paste_shape = np.array(paste_img.size, dtype=np.uint) + # change paste_img's brightness randomly + filter = ImageEnhance.Brightness(paste_img) + paste_img = filter.enhance(rd.uniform(0.5, 1.5)) + + # change paste_img's contrast randomly + filter = ImageEnhance.Contrast(paste_img) + paste_img = filter.enhance(rd.uniform(0.5, 1.5)) + + # change paste_img's sharpness randomly + filter = ImageEnhance.Sharpness(paste_img) + paste_img = filter.enhance(rd.uniform(0.5, 1.5)) + # compute the minimum scaling to fit inside target image min_scale = np.min(target_shape / paste_shape) @@ -117,10 +129,6 @@ class RandomPaste(A.DualTransform): # update paste_shape after scaling paste_shape = np.array(paste_img.size, dtype=np.uint) - # change brightness randomly - filter = ImageEnhance.Brightness(paste_img) - paste_img = filter.enhance(rd.uniform(0.5, 1.5)) - # generate some positions positions = [] NB = rd.randint(1, self.nb)