From 8ed1e09b2ab8da009a805fd404cb823b1d8c5b02 Mon Sep 17 00:00:00 2001
From: milesial
Date: Wed, 30 Oct 2019 19:54:57 +0100
Subject: [PATCH] Better support for multiclass

Former-commit-id: 76bebf5f241f579fda7048f5e4a87ee9d49aa423
---
 eval.py            |  7 ++++++-
 train.py           | 13 ++++++++++---
 unet/unet_model.py |  4 +++-
 3 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/eval.py b/eval.py
index 2944bea..e48d522 100644
--- a/eval.py
+++ b/eval.py
@@ -1,4 +1,5 @@
 import torch
+import torch.nn.functional as F
 from tqdm import tqdm
 
 from dice_loss import dice_coeff
@@ -22,5 +23,9 @@ def eval_net(net, dataset, device, n_val):
             mask_pred = net(img).squeeze(dim=0)
 
             mask_pred = (mask_pred > 0.5).float()
-            tot += dice_coeff(mask_pred, true_mask.squeeze(dim=1)).item()
+            if net.n_classes > 1:
+                tot += F.cross_entropy(mask_pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item()
+            else:
+                tot += dice_coeff(mask_pred, true_mask.squeeze(dim=1)).item()
+
     return tot / n_val
diff --git a/train.py b/train.py
index 379a683..d875b64 100644
--- a/train.py
+++ b/train.py
@@ -44,7 +44,10 @@ def train_net(net,
     n_train = len(iddataset['train'])
     n_val = len(iddataset['val'])
     optimizer = optim.Adam(net.parameters(), lr=lr)
-    criterion = nn.BCELoss()
+    if net.n_classes > 1:
+        criterion = nn.CrossEntropyLoss()
+    else:
+        criterion = nn.BCEWithLogitsLoss()
 
     for epoch in range(epochs):
         net.train()
@@ -87,8 +90,12 @@ def train_net(net,
                        dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
             logging.info(f'Checkpoint {epoch + 1} saved !')
 
-        val_dice = eval_net(net, val, device, n_val)
-        logging.info('Validation Dice Coeff: {}'.format(val_dice))
+        val_score = eval_net(net, val, device, n_val)
+        if net.n_classes > 1:
+            logging.info('Validation cross entropy: {}'.format(val_score))
+
+        else:
+            logging.info('Validation Dice Coeff: {}'.format(val_score))
 
 
 def get_args():
diff --git a/unet/unet_model.py b/unet/unet_model.py
index 466222a..5b23b55 100644
--- a/unet/unet_model.py
+++ b/unet/unet_model.py
@@ -33,7 +33,9 @@ class UNet(nn.Module):
         x = self.up2(x, x3)
         x = self.up3(x, x2)
         x = self.up4(x, x1)
-        x = self.outc(x)
+        logits = self.outc(x)
+        return logits
+
         if self.n_classes > 1:
             return F.softmax(x, dim=1)
         else:
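
Note (editorial, not part of the patch): the sketch below illustrates how the criterion selection added in train.py behaves. It assumes only a class count, mirroring the net.n_classes attribute the patch relies on; the build_criterion helper and the tensor shapes are hypothetical names chosen for illustration, not code from the repository.

    import torch
    import torch.nn as nn

    def build_criterion(n_classes):
        # Mirrors the branch added in train.py: CrossEntropyLoss for several
        # classes, BCEWithLogitsLoss for a single class. Both losses take raw
        # logits, so the network output is not passed through softmax/sigmoid
        # beforehand.
        if n_classes > 1:
            return nn.CrossEntropyLoss()
        return nn.BCEWithLogitsLoss()

    # Multiclass case (shapes assumed for illustration):
    # logits (N, C, H, W) against an integer class map (N, H, W).
    logits = torch.randn(2, 3, 64, 64)
    target = torch.randint(0, 3, (2, 64, 64))
    print(build_criterion(3)(logits, target).item())

    # Single-class case: logits and float mask share the shape (N, 1, H, W).
    logits = torch.randn(2, 1, 64, 64)
    mask = torch.randint(0, 2, (2, 1, 64, 64)).float()
    print(build_criterion(1)(logits, mask).item())

Both criteria operate on raw logits, which is consistent with the change in unet/unet_model.py that makes forward() return logits directly.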