diff --git a/eval.py b/eval.py index f293bd1..003d782 100644 --- a/eval.py +++ b/eval.py @@ -20,12 +20,12 @@ def eval_net(net, loader, device, n_val): mask_pred = net(imgs) - for true_mask in true_masks: - mask_pred = (mask_pred > 0.5).float() + for true_mask, pred in zip(true_masks, mask_pred): + pred = (pred > 0.5).float() if net.n_classes > 1: - tot += F.cross_entropy(mask_pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item() + tot += F.cross_entropy(pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item() else: - tot += dice_coeff(mask_pred, true_mask.squeeze(dim=1)).item() + tot += dice_coeff(pred, true_mask.squeeze(dim=1)).item() pbar.update(imgs.shape[0]) return tot / n_val