diff --git a/src/evaluate.py b/src/evaluate.py index af3d1dc..cd5db85 100644 --- a/src/evaluate.py +++ b/src/evaluate.py @@ -11,20 +11,21 @@ def evaluate(net, dataloader, device): dice_score = 0 # iterate over the validation set - with tqdm(dataloader, total=len(dataloader.dataset), desc="Validation", unit="img", leave=False) as pbar: + with tqdm(dataloader, total=len(dataloader.dataset), desc="val", unit="img", leave=False) as pbar: for images, masks_true in dataloader: # move images and labels to correct device images = images.to(device=device) - masks_true = masks_true.unsqueeze(1).to(device=device) + masks_true = masks_true.unsqueeze(1).float().to(device=device) + # forward, predict the mask with torch.inference_mode(): - # predict the mask masks_pred = net(images) masks_pred = (torch.sigmoid(masks_pred) > 0.5).float() # compute the Dice score dice_score += dice_coeff(masks_pred, masks_true, reduce_batch_first=False) + # update progress bar pbar.update(images.shape[0]) # save some images to wandb