Better support for multiclass
Former-commit-id: 76bebf5f241f579fda7048f5e4a87ee9d49aa423
This commit is contained in:
parent
6f23624412
commit
8ed1e09b2a
7
eval.py
7
eval.py
|
@ -1,4 +1,5 @@
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from dice_loss import dice_coeff
|
from dice_loss import dice_coeff
|
||||||
|
@ -22,5 +23,9 @@ def eval_net(net, dataset, device, n_val):
|
||||||
mask_pred = net(img).squeeze(dim=0)
|
mask_pred = net(img).squeeze(dim=0)
|
||||||
|
|
||||||
mask_pred = (mask_pred > 0.5).float()
|
mask_pred = (mask_pred > 0.5).float()
|
||||||
tot += dice_coeff(mask_pred, true_mask.squeeze(dim=1)).item()
|
if net.n_classes > 1:
|
||||||
|
tot += F.cross_entropy(mask_pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item()
|
||||||
|
else:
|
||||||
|
tot += dice_coeff(mask_pred, true_mask.squeeze(dim=1)).item()
|
||||||
|
|
||||||
return tot / n_val
|
return tot / n_val
|
||||||
|
|
13
train.py
13
train.py
|
@ -44,7 +44,10 @@ def train_net(net,
|
||||||
n_train = len(iddataset['train'])
|
n_train = len(iddataset['train'])
|
||||||
n_val = len(iddataset['val'])
|
n_val = len(iddataset['val'])
|
||||||
optimizer = optim.Adam(net.parameters(), lr=lr)
|
optimizer = optim.Adam(net.parameters(), lr=lr)
|
||||||
criterion = nn.BCELoss()
|
if net.n_classes > 1:
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
else:
|
||||||
|
criterion = nn.BCEWithLogitsLoss()
|
||||||
|
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
net.train()
|
net.train()
|
||||||
|
@ -87,8 +90,12 @@ def train_net(net,
|
||||||
dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
|
dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
|
||||||
logging.info(f'Checkpoint {epoch + 1} saved !')
|
logging.info(f'Checkpoint {epoch + 1} saved !')
|
||||||
|
|
||||||
val_dice = eval_net(net, val, device, n_val)
|
val_score = eval_net(net, val, device, n_val)
|
||||||
logging.info('Validation Dice Coeff: {}'.format(val_dice))
|
if net.n_classes > 1:
|
||||||
|
logging.info('Validation cross entropy: {}'.format(val_score))
|
||||||
|
|
||||||
|
else:
|
||||||
|
logging.info('Validation Dice Coeff: {}'.format(val_score))
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
|
|
|
@ -33,7 +33,9 @@ class UNet(nn.Module):
|
||||||
x = self.up2(x, x3)
|
x = self.up2(x, x3)
|
||||||
x = self.up3(x, x2)
|
x = self.up3(x, x2)
|
||||||
x = self.up4(x, x1)
|
x = self.up4(x, x1)
|
||||||
x = self.outc(x)
|
logits = self.outc(x)
|
||||||
|
return logits
|
||||||
|
|
||||||
if self.n_classes > 1:
|
if self.n_classes > 1:
|
||||||
return F.softmax(x, dim=1)
|
return F.softmax(x, dim=1)
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in a new issue