diff --git a/eval.py b/eval.py index 3b1867f..d126c1e 100644 --- a/eval.py +++ b/eval.py @@ -29,4 +29,5 @@ def eval_net(net, loader, device): tot += dice_coeff(pred, true_masks).item() pbar.update() + net.train() return tot / n_val