feat: add some comments to evaluate.py

Former-commit-id: 0181b044ec0497d95c80e9e7c5112537806e0afc [formerly 1f62452a8fa1bbf6cfa5c82347c58cc1886ae796]
Former-commit-id: 31dba554cbd0126009890ac54dfd6b6d38f2a47c
This commit is contained in:
Laurent Fainsin 2022-06-29 10:10:56 +02:00
parent 92ac3a2ab8
commit 806e98c744

View file

@ -11,20 +11,21 @@ def evaluate(net, dataloader, device):
dice_score = 0
# iterate over the validation set
with tqdm(dataloader, total=len(dataloader.dataset), desc="Validation", unit="img", leave=False) as pbar:
with tqdm(dataloader, total=len(dataloader.dataset), desc="val", unit="img", leave=False) as pbar:
for images, masks_true in dataloader:
# move images and labels to correct device
images = images.to(device=device)
masks_true = masks_true.unsqueeze(1).to(device=device)
masks_true = masks_true.unsqueeze(1).float().to(device=device)
# forward, predict the mask
with torch.inference_mode():
# predict the mask
masks_pred = net(images)
masks_pred = (torch.sigmoid(masks_pred) > 0.5).float()
# compute the Dice score
dice_score += dice_coeff(masks_pred, masks_true, reduce_batch_first=False)
# update progress bar
pbar.update(images.shape[0])
# save some images to wandb