feat: add some comments to evaluate.py
Former-commit-id: 0181b044ec0497d95c80e9e7c5112537806e0afc [formerly 1f62452a8fa1bbf6cfa5c82347c58cc1886ae796] Former-commit-id: 31dba554cbd0126009890ac54dfd6b6d38f2a47c
This commit is contained in:
parent
92ac3a2ab8
commit
806e98c744
|
@ -11,20 +11,21 @@ def evaluate(net, dataloader, device):
|
|||
dice_score = 0
|
||||
|
||||
# iterate over the validation set
|
||||
with tqdm(dataloader, total=len(dataloader.dataset), desc="Validation", unit="img", leave=False) as pbar:
|
||||
with tqdm(dataloader, total=len(dataloader.dataset), desc="val", unit="img", leave=False) as pbar:
|
||||
for images, masks_true in dataloader:
|
||||
# move images and labels to correct device
|
||||
images = images.to(device=device)
|
||||
masks_true = masks_true.unsqueeze(1).to(device=device)
|
||||
masks_true = masks_true.unsqueeze(1).float().to(device=device)
|
||||
|
||||
# forward, predict the mask
|
||||
with torch.inference_mode():
|
||||
# predict the mask
|
||||
masks_pred = net(images)
|
||||
masks_pred = (torch.sigmoid(masks_pred) > 0.5).float()
|
||||
|
||||
# compute the Dice score
|
||||
dice_score += dice_coeff(masks_pred, masks_true, reduce_batch_first=False)
|
||||
|
||||
# update progress bar
|
||||
pbar.update(images.shape[0])
|
||||
|
||||
# save some images to wandb
|
||||
|
|
Loading…
Reference in a new issue