From 806e98c744ec0ce968bee44d7741b179aac07d41 Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Wed, 29 Jun 2022 10:10:56 +0200 Subject: [PATCH] feat: add some comments to evaluate.py Former-commit-id: 0181b044ec0497d95c80e9e7c5112537806e0afc [formerly 1f62452a8fa1bbf6cfa5c82347c58cc1886ae796] Former-commit-id: 31dba554cbd0126009890ac54dfd6b6d38f2a47c --- src/evaluate.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/evaluate.py b/src/evaluate.py index af3d1dc..cd5db85 100644 --- a/src/evaluate.py +++ b/src/evaluate.py @@ -11,20 +11,21 @@ def evaluate(net, dataloader, device): dice_score = 0 # iterate over the validation set - with tqdm(dataloader, total=len(dataloader.dataset), desc="Validation", unit="img", leave=False) as pbar: + with tqdm(dataloader, total=len(dataloader.dataset), desc="val", unit="img", leave=False) as pbar: for images, masks_true in dataloader: # move images and labels to correct device images = images.to(device=device) - masks_true = masks_true.unsqueeze(1).to(device=device) + masks_true = masks_true.unsqueeze(1).float().to(device=device) + # forward, predict the mask with torch.inference_mode(): - # predict the mask masks_pred = net(images) masks_pred = (torch.sigmoid(masks_pred) > 0.5).float() # compute the Dice score dice_score += dice_coeff(masks_pred, masks_true, reduce_batch_first=False) + # update progress bar pbar.update(images.shape[0]) # save some images to wandb