From d292e8c6cd6111bb99a60bcccaed67e1bfefe42a Mon Sep 17 00:00:00 2001
From: whenyd
Date: Wed, 11 Mar 2020 16:06:23 +0800
Subject: [PATCH 1/2] Apply sigmoid before calc dice in eval_net()

Former-commit-id: 0da18fda34f29c81968425715e19c5dc76c9ec46
---
 eval.py  | 26 +++++++++++++-------------
 train.py |  2 +-
 2 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/eval.py b/eval.py
index 761a2d3..3b1867f 100644
--- a/eval.py
+++ b/eval.py
@@ -5,28 +5,28 @@ from tqdm import tqdm
 from dice_loss import dice_coeff
 
 
-def eval_net(net, loader, device, n_val):
+def eval_net(net, loader, device):
     """Evaluation without the densecrf with the dice coefficient"""
     net.eval()
+    mask_type = torch.float32 if net.n_classes == 1 else torch.long
+    n_val = len(loader)  # the number of batch
     tot = 0
 
-    with tqdm(total=n_val, desc='Validation round', unit='img', leave=False) as pbar:
+    with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
         for batch in loader:
-            imgs = batch['image']
-            true_masks = batch['mask']
-
+            imgs, true_masks = batch['image'], batch['mask']
             imgs = imgs.to(device=device, dtype=torch.float32)
-            mask_type = torch.float32 if net.n_classes == 1 else torch.long
             true_masks = true_masks.to(device=device, dtype=mask_type)
 
-            mask_pred = net(imgs)
+            with torch.no_grad():
+                mask_pred = net(imgs)
 
-            for true_mask, pred in zip(true_masks, mask_pred):
+            if net.n_classes > 1:
+                tot += F.cross_entropy(mask_pred, true_masks).item()
+            else:
+                pred = torch.sigmoid(mask_pred)
                 pred = (pred > 0.5).float()
-                if net.n_classes > 1:
-                    tot += F.cross_entropy(pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item()
-                else:
-                    tot += dice_coeff(pred, true_mask.squeeze(dim=1)).item()
-            pbar.update(imgs.shape[0])
+                tot += dice_coeff(pred, true_masks).item()
+            pbar.update()
 
     return tot / n_val
diff --git a/train.py b/train.py
index d52c8b6..5ac3f3b 100644
--- a/train.py
+++ b/train.py
@@ -88,7 +88,7 @@ def train_net(net,
                 pbar.update(imgs.shape[0])
                 global_step += 1
                 if global_step % (len(dataset) // (10 * batch_size)) == 0:
-                    val_score = eval_net(net, val_loader, device, n_val)
+                    val_score = eval_net(net, val_loader, device)
                     if net.n_classes > 1:
                         logging.info('Validation cross entropy: {}'.format(val_score))
                         writer.add_scalar('Loss/test', val_score, global_step)

From cf505c65495f5cbbad0790c631677c99b308b740 Mon Sep 17 00:00:00 2001
From: whenyd
Date: Fri, 13 Mar 2020 11:21:16 +0800
Subject: [PATCH 2/2] Set `drop_last=True` for val_loader

Former-commit-id: adb1d7e9348a2707f38d9b86f57bb3ae2cbc2b73
---
 train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train.py b/train.py
index 5ac3f3b..90a60ae 100644
--- a/train.py
+++ b/train.py
@@ -35,7 +35,7 @@ def train_net(net,
     n_train = len(dataset) - n_val
     train, val = random_split(dataset, [n_train, n_val])
     train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
-    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)
+    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True, drop_last=True)
 
     writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
     global_step = 0
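
For reference, a sketch of how eval.py reads once patch 1/2 applies (patch 2/2 only touches train.py). The `import torch` and `import torch.nn.functional as F` lines are assumed from the `torch.sigmoid` and `F.cross_entropy` calls in the hunk rather than shown in it, so treat this as a reconstruction of the post-patch function, not the verbatim file:

import torch
import torch.nn.functional as F
from tqdm import tqdm

from dice_loss import dice_coeff


def eval_net(net, loader, device):
    """Evaluation without the densecrf with the dice coefficient"""
    net.eval()
    mask_type = torch.float32 if net.n_classes == 1 else torch.long
    n_val = len(loader)  # the number of batches
    tot = 0

    with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
        for batch in loader:
            imgs, true_masks = batch['image'], batch['mask']
            imgs = imgs.to(device=device, dtype=torch.float32)
            true_masks = true_masks.to(device=device, dtype=mask_type)

            # Validation is inference only, so no gradients are tracked
            with torch.no_grad():
                mask_pred = net(imgs)

            if net.n_classes > 1:
                # Multi-class: mean cross entropy over the whole batch
                tot += F.cross_entropy(mask_pred, true_masks).item()
            else:
                # Binary: sigmoid first, then threshold, then dice on the batch
                pred = torch.sigmoid(mask_pred)
                pred = (pred > 0.5).float()
                tot += dice_coeff(pred, true_masks).item()
            pbar.update()

    return tot / n_val

Because tot is now accumulated once per batch and divided by n_val = len(loader), the drop_last=True added to val_loader in patch 2/2 keeps every validation batch the same size, so the per-batch average is not skewed by a smaller final batch (at the cost of skipping up to batch_size - 1 validation samples each round).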