mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
Fix evaluation
Former-commit-id: 2b649bc9a337818696291280cd87ad93a6fc5032
This commit is contained in:
parent
890719f4b7
commit
8c761f9d06
|
@ -24,12 +24,12 @@ def evaluate(net, dataloader, device):
|
||||||
|
|
||||||
# convert to one-hot format
|
# convert to one-hot format
|
||||||
if net.n_classes == 1:
|
if net.n_classes == 1:
|
||||||
mask_pred = (F.sigmoid(mask_pred) > 0).float()
|
mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
|
||||||
else:
|
else:
|
||||||
mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
|
mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
|
||||||
|
|
||||||
# compute the Dice score, ignoring background
|
# compute the Dice score, ignoring background
|
||||||
dice_score += multiclass_dice_coeff(mask_pred[:, :1, ...], mask_true[:, :1, ...], reduce_batch_first=False)
|
dice_score += multiclass_dice_coeff(mask_pred[:, 1:, ...], mask_true[:, 1:, ...], reduce_batch_first=False)
|
||||||
|
|
||||||
net.train()
|
net.train()
|
||||||
return dice_score / num_val_batches
|
return dice_score / num_val_batches
|
||||||
|
|
Loading…
Reference in a new issue