import sys
import os
from optparse import OptionParser

import numpy as np

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

from eval import eval_net
from unet import UNet
from utils import get_ids, split_ids, split_train_val, get_imgs_and_masks, batch


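# Train a UNet on the images in dir_img against the binary masks in dir_mask:
# pixel-wise BCE on the sigmoid output, SGD with momentum, Dice evaluation on
# a held-out split, and a checkpoint after each epoch.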
def train_net(net,
              epochs=5,
              batch_size=1,
              lr=0.1,
              val_percent=0.05,
              save_cp=True,
              gpu=False,
              img_scale=0.5):
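    # Directories for the training images, their ground-truth masks,
    # and the per-epoch checkpoints.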
    dir_img = 'data/train/'
    dir_mask = 'data/train_masks/'
    dir_checkpoint = 'checkpoints/'

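    # Build the list of image ids, expand each id via the split_ids helper
    # from utils, then carve out val_percent of them for validation.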
    ids = get_ids(dir_img)
    ids = split_ids(ids)

    iddataset = split_train_val(ids, val_percent)

    print('''
    Starting training:
        Epochs: {}
        Batch size: {}
        Learning rate: {}
        Training size: {}
        Validation size: {}
        Checkpoints: {}
        CUDA: {}
    '''.format(epochs, batch_size, lr, len(iddataset['train']),
               len(iddataset['val']), str(save_cp), str(gpu)))

    N_train = len(iddataset['train'])

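    # Plain SGD with momentum and weight decay; the loss is pixel-wise
    # binary cross-entropy on the sigmoid output of the network.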
    optimizer = optim.SGD(net.parameters(),
                          lr=lr,
                          momentum=0.9,
                          weight_decay=0.0005)

    criterion = nn.BCELoss()

    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        net.train()

        # reset the generators
        train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask, img_scale)
        val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask, img_scale)

        epoch_loss = 0

        for i, b in enumerate(batch(train, batch_size)):
            # Assemble the batch: images as float32 arrays, masks as-is.
            imgs = np.array([item[0] for item in b]).astype(np.float32)
            true_masks = np.array([item[1] for item in b])

            imgs = torch.from_numpy(imgs)
            true_masks = torch.from_numpy(true_masks)

            if gpu:
                imgs = imgs.cuda()
                true_masks = true_masks.cuda()

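            # Forward pass: raw network outputs are squashed with a sigmoid
            # and flattened so BCELoss can compare them pixel by pixel
            # against the flattened ground-truth masks.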
            masks_pred = net(imgs)
            masks_probs = F.sigmoid(masks_pred)
            masks_probs_flat = masks_probs.view(-1)

            true_masks_flat = true_masks.view(-1)

            loss = criterion(masks_probs_flat, true_masks_flat)
            epoch_loss += loss.item()

            print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train, loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print('Epoch finished! Loss: {}'.format(epoch_loss / (i + 1)))

        # Evaluate on the held-out set after every epoch.
        val_dice = eval_net(net, val, gpu)
        print('Validation Dice Coeff: {}'.format(val_dice))

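        # Save the model weights after every epoch so a run can be
        # resumed later with the -c/--load option.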
        if save_cp:
            torch.save(net.state_dict(),
                       dir_checkpoint + 'CP{}.pth'.format(epoch + 1))
            print('Checkpoint {} saved!'.format(epoch + 1))


def get_args():
    parser = OptionParser()
    parser.add_option('-e', '--epochs', dest='epochs', default=5, type='int',
                      help='number of epochs')
    parser.add_option('-b', '--batch-size', dest='batchsize', default=10,
                      type='int', help='batch size')
    parser.add_option('-l', '--learning-rate', dest='lr', default=0.1,
                      type='float', help='learning rate')
    parser.add_option('-g', '--gpu', action='store_true', dest='gpu',
                      default=False, help='use cuda')
    parser.add_option('-c', '--load', dest='load',
                      default=False, help='load model from a file')
    parser.add_option('-s', '--scale', dest='scale', type='float',
                      default=0.5, help='downscaling factor of the images')

    (options, args) = parser.parse_args()
    return options


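# Example invocation (assuming this file is saved as train.py):
#   python train.py -e 5 -b 10 -l 0.1 -s 0.5 -g
#   python train.py -g -c checkpoints/CP5.pth   # resume from a saved checkpoint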
if __name__ == '__main__':
    args = get_args()

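    # 3 input channels (RGB images), 1 output channel (binary mask).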
    net = UNet(n_channels=3, n_classes=1)

    if args.load:
        net.load_state_dict(torch.load(args.load))
        print('Model loaded from {}'.format(args.load))

    if args.gpu:
        net.cuda()
        # cudnn.benchmark = True # faster convolutions, but more memory

    try:
        train_net(net=net,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
                  gpu=args.gpu,
                  img_scale=args.scale)
    except KeyboardInterrupt:
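        # On Ctrl+C, snapshot the current weights so the run can be
        # resumed later via --load.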
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        print('Saved interrupt')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)