feat: add some comments to evaluate.py
Former-commit-id: 0181b044ec0497d95c80e9e7c5112537806e0afc [formerly 1f62452a8fa1bbf6cfa5c82347c58cc1886ae796] Former-commit-id: 31dba554cbd0126009890ac54dfd6b6d38f2a47c
This commit is contained in:
parent
92ac3a2ab8
commit
806e98c744
|
@ -11,20 +11,21 @@ def evaluate(net, dataloader, device):
|
||||||
dice_score = 0
|
dice_score = 0
|
||||||
|
|
||||||
# iterate over the validation set
|
# iterate over the validation set
|
||||||
with tqdm(dataloader, total=len(dataloader.dataset), desc="Validation", unit="img", leave=False) as pbar:
|
with tqdm(dataloader, total=len(dataloader.dataset), desc="val", unit="img", leave=False) as pbar:
|
||||||
for images, masks_true in dataloader:
|
for images, masks_true in dataloader:
|
||||||
# move images and labels to correct device
|
# move images and labels to correct device
|
||||||
images = images.to(device=device)
|
images = images.to(device=device)
|
||||||
masks_true = masks_true.unsqueeze(1).to(device=device)
|
masks_true = masks_true.unsqueeze(1).float().to(device=device)
|
||||||
|
|
||||||
|
# forward, predict the mask
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
# predict the mask
|
|
||||||
masks_pred = net(images)
|
masks_pred = net(images)
|
||||||
masks_pred = (torch.sigmoid(masks_pred) > 0.5).float()
|
masks_pred = (torch.sigmoid(masks_pred) > 0.5).float()
|
||||||
|
|
||||||
# compute the Dice score
|
# compute the Dice score
|
||||||
dice_score += dice_coeff(masks_pred, masks_true, reduce_batch_first=False)
|
dice_score += dice_coeff(masks_pred, masks_true, reduce_batch_first=False)
|
||||||
|
|
||||||
|
# update progress bar
|
||||||
pbar.update(images.shape[0])
|
pbar.update(images.shape[0])
|
||||||
|
|
||||||
# save some images to wandb
|
# save some images to wandb
|
||||||
|
|
Loading…
Reference in a new issue