From 5a7e9345606eea5d9a469bd16dfb681b357195aa Mon Sep 17 00:00:00 2001 From: milesial Date: Mon, 2 Dec 2019 12:06:29 +0100 Subject: [PATCH] Fix eval.py iterating within batch Former-commit-id: d4b8040593d4bc82774b1dac647c701b2486c477 --- eval.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/eval.py b/eval.py index f293bd1..003d782 100644 --- a/eval.py +++ b/eval.py @@ -20,12 +20,12 @@ def eval_net(net, loader, device, n_val): mask_pred = net(imgs) - for true_mask in true_masks: - mask_pred = (mask_pred > 0.5).float() + for true_mask, pred in zip(true_masks, mask_pred): + pred = (pred > 0.5).float() if net.n_classes > 1: - tot += F.cross_entropy(mask_pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item() + tot += F.cross_entropy(pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item() else: - tot += dice_coeff(mask_pred, true_mask.squeeze(dim=1)).item() + tot += dice_coeff(pred, true_mask.squeeze(dim=1)).item() pbar.update(imgs.shape[0]) return tot / n_val