diff --git a/crf.py b/crf.py
new file mode 100644
index 0000000..06cb61e
--- /dev/null
+++ b/crf.py
@@ -0,0 +1,34 @@
+import numpy as np
+import pydensecrf.densecrf as dcrf
+
+
+def dense_crf(img, output_probs):
+    """Refine an HxW foreground probability map with a dense CRF over img."""
+    h = output_probs.shape[0]
+    w = output_probs.shape[1]
+
+    # Clip to avoid log(0) when the map contains hard 0/1 values
+    output_probs = np.clip(output_probs, 1e-8, 1 - 1e-8)
+
+    # Stack background and foreground probabilities into a 2-class map
+    output_probs = np.expand_dims(output_probs, 0)
+    output_probs = np.append(1 - output_probs, output_probs, axis=0)
+
+    d = dcrf.DenseCRF2D(w, h, 2)
+
+    # Unary potentials are the negative log-probabilities of each class
+    U = -np.log(output_probs)
+    U = U.reshape((2, -1))
+    U = np.ascontiguousarray(U)
+    img = np.ascontiguousarray(img)
+
+    d.setUnaryEnergy(U)
+
+    # Pairwise terms: location smoothness and color-sensitive appearance
+    d.addPairwiseGaussian(sxy=10, compat=3)
+    d.addPairwiseBilateral(sxy=50, srgb=20, rgbim=img, compat=10)
+
+    Q = d.inference(30)
+    Q = np.argmax(np.array(Q), axis=0).reshape((h, w))
+
+    return Q
diff --git a/eval.py b/eval.py
new file mode 100644
index 0000000..cf2ea37
--- /dev/null
+++ b/eval.py
@@ -0,0 +1,55 @@
+import torch
+from myloss import dice_coeff
+import numpy as np
+from torch.autograd import Variable
+import matplotlib.pyplot as plt
+import torch.nn.functional as F
+from crf import dense_crf
+
+
+def eval_net(net, dataset, gpu=False, viz=False):
+    """Compute the mean Dice coefficient of net over dataset."""
+    tot = 0
+    for i, b in enumerate(dataset):
+        X = b[0]
+        y = b[1]
+
+        X = torch.FloatTensor(X).unsqueeze(0)
+        y = torch.ByteTensor(y).unsqueeze(0)
+
+        if gpu:
+            X = Variable(X, volatile=True).cuda()
+            y = Variable(y, volatile=True).cuda()
+        else:
+            X = Variable(X, volatile=True)
+            y = Variable(y, volatile=True)
+
+        y_pred = net(X)
+
+        # Binarize the sigmoid output at a 0.6 threshold
+        y_pred = (F.sigmoid(y_pred) > 0.6).float()
+
+        dice = dice_coeff(y_pred, y.float()).data[0]
+        tot += dice
+
+        if viz:
+            X = X.data.squeeze(0).cpu().numpy()
+            X = np.transpose(X, axes=[1, 2, 0])
+            y = y.data.squeeze(0).cpu().numpy()
+            y_pred = y_pred.data.squeeze(0).squeeze(0).cpu().numpy()
+
+            fig = plt.figure()
+            ax1 = fig.add_subplot(1, 4, 1)
+            ax1.imshow(X)
+            ax2 = fig.add_subplot(1, 4, 2)
+            ax2.imshow(y)
+            ax3 = fig.add_subplot(1, 4, 3)
+            ax3.imshow(y_pred > 0.6)
+
+            # Refine the thresholded mask with a dense CRF and show it
+            Q = dense_crf(((X * 255).round()).astype(np.uint8), y_pred)
+            ax4 = fig.add_subplot(1, 4, 4)
+            ax4.imshow(Q)
+            plt.show()
+
+    return tot / (i + 1)
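
The dense CRF is only exercised from the visualization branch of eval_net. For reference, a minimal sketch of calling dense_crf on its own, with a dummy uint8 image and probability map (all shapes and values below are made up for illustration):

    import numpy as np
    from crf import dense_crf

    # Hypothetical inputs: an HxWx3 color image and an HxW sigmoid map
    img = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
    probs = np.random.rand(640, 640).astype(np.float32)

    mask = dense_crf(img, probs)  # (640, 640) array of 0/1 class labels
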
diff --git a/main.py b/main.py
index 9df9eef..6796781 100644
--- a/main.py
+++ b/main.py
@@ -8,17 +8,20 @@ from torch import optim
 #data manipulation
 import numpy as np
 import pandas as pd
-import cv2
 import PIL
 
 #load files
 import os
 
-#data vis
+#data visualization
 from data_vis import plot_img_mask
 from utils import *
 import matplotlib.pyplot as plt
 
+#quit after interrupt
+import sys
+
+
 dir = 'data'
 
 ids = []
@@ -33,69 +36,72 @@ np.random.shuffle(ids)
 
 net = UNet(3, 1)
+net.cuda()
 
-optimizer = optim.Adam(net.parameters(), lr=0.001)
-criterion = DiceLoss()
+
+def train(net):
+    optimizer = optim.Adam(net.parameters(), lr=1)
+    criterion = DiceLoss()
 
-dataset = []
-epochs = 5
-for epoch in range(epochs):
-    print('epoch {}/{}...'.format(epoch+1, epochs))
-    l = 0
+    epochs = 5
+    for epoch in range(epochs):
+        print('epoch {}/{}...'.format(epoch+1, epochs))
+        l = 0
 
-    for i, c in enumerate(ids):
-        id = c[0]
-        pos = c[1]
-        im = PIL.Image.open(dir + '/train/' + id + '.jpg')
-        im = resize_and_crop(im)
+        for i, c in enumerate(ids):
+            id = c[0]
+            pos = c[1]
+            im = PIL.Image.open(dir + '/train/' + id + '.jpg')
+            im = resize_and_crop(im)
 
-        ma = PIL.Image.open(dir + '/train_masks/' + id + '_mask.gif')
-        ma = resize_and_crop(ma)
+            ma = PIL.Image.open(dir + '/train_masks/' + id + '_mask.gif')
+            ma = resize_and_crop(ma)
 
-        left, right = split_into_squares(np.array(im))
-        left_m, right_m = split_into_squares(np.array(ma))
+            left, right = split_into_squares(np.array(im))
+            left_m, right_m = split_into_squares(np.array(ma))
 
-        if pos == 0:
-            X = left
-            y = left_m
-        else:
-            X = right
-            y = right_m
+            if pos == 0:
+                X = left
+                y = left_m
+            else:
+                X = right
+                y = right_m
 
-        X = np.transpose(X, axes=[2, 0, 1])
-        X = torch.FloatTensor(X / 255).unsqueeze(0)
-        y = Variable(torch.ByteTensor(y))
+            X = np.transpose(X, axes=[2, 0, 1])
+            X = torch.FloatTensor(X / 255).unsqueeze(0).cuda()
+            y = Variable(torch.ByteTensor(y)).cuda()
 
-        X = Variable(X, requires_grad=False)
+            X = Variable(X).cuda()
 
-        optimizer.zero_grad()
+            optimizer.zero_grad()
 
-        y_pred = net(X).squeeze(1)
+            y_pred = net(X).squeeze(1)
 
-        loss = criterion(y_pred, y.unsqueeze(0).float())
+            loss = criterion(y_pred, y.unsqueeze(0).float())
 
-        l += loss.data[0]
-        loss.backward()
-        optimizer.step()
+            l += loss.data[0]
+            loss.backward()
+            # Only apply an update every 10 images
+            if i % 10 == 0:
+                optimizer.step()
+                print('Stepped')
 
-        print('{0:.4f}%.'.format(i/len(ids)*100, end=''))
+            print('{0:.4f}%\t\t{1:.6f}'.format(i/len(ids)*100, loss.data[0]))
 
-    print('Loss : {}'.format(l))
+        l = l / len(ids)
+        print('Loss : {}'.format(l))
+        torch.save(net.state_dict(), 'MODEL_EPOCH{}_LOSS{}.pth'.format(epoch+1, l))
+        print('Saved')
 
+try:
+    net.load_state_dict(torch.load('MODEL_INTERRUPTED.pth'))
+    train(net)
 
-#%%
-
-
-
-
-#net = UNet(3, 2)
-
-#x = Variable(torch.FloatTensor(np.random.randn(1, 3, 640, 640)))
-
-#y = net(x)
-
-
-#plt.imshow(y[0])
-#plt.show()
+except KeyboardInterrupt:
+    print('Interrupted')
+    torch.save(net.state_dict(), 'MODEL_INTERRUPTED.pth')
+    try:
+        sys.exit(0)
+    except SystemExit:
+        os._exit(0)
diff --git a/train.py b/train.py
index 2dedee2..c09d1d3 100644
--- a/train.py
+++ b/train.py
@@ -1,13 +1,19 @@
 import torch
+import torch.backends.cudnn as cudnn
+import torch.nn.functional as F
+import torch.nn as nn
 from load import *
 from data_vis import *
 from utils import split_train_val, batch
 from myloss import DiceLoss
+from eval import eval_net
 from unet_model import UNet
 from torch.autograd import Variable
 from torch import optim
 from optparse import OptionParser
+import sys
+import os
 
 
 def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
@@ -39,14 +45,21 @@ def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
     train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask)
     val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask)
 
-    optimizer = optim.Adam(net.parameters(), lr=lr)
-    criterion = DiceLoss()
+    optimizer = optim.SGD(net.parameters(),
+                          lr=lr, momentum=0.9, weight_decay=0.0005)
+    criterion = nn.BCELoss()
 
     for epoch in range(epochs):
         print('Starting epoch {}/{}.'.format(epoch+1, epochs))
+        # The generators are exhausted after one pass, so rebuild them
+        train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask)
+        val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask)
         epoch_loss = 0
 
+        val_dice = eval_net(net, val, gpu)
+        print('Validation Dice Coeff: {}'.format(val_dice))
+
         for i, b in enumerate(batch(train, batch_size)):
             X = np.array([i[0] for i in b])
             y = np.array([i[1] for i in b])
@@ -61,17 +74,22 @@ def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
             X = Variable(X)
             y = Variable(y)
 
-            optimizer.zero_grad()
-
             y_pred = net(X)
+            probs = F.sigmoid(y_pred)
+            probs_flat = probs.view(-1)
 
-            loss = criterion(y_pred, y.float())
+            y_flat = y.view(-1)
+
+            loss = criterion(probs_flat, y_flat.float())
             epoch_loss += loss.data[0]
 
             print('{0:.4f} --- loss: {1:.6f}'.format(i*batch_size/N_train,
-                                                     loss.data[0]))
+                                                      loss.data[0]))
+
+            optimizer.zero_grad()
             loss.backward()
+            optimizer.step()
 
         print('Epoch finished ! Loss: {}'.format(epoch_loss/i))
@@ -83,23 +101,38 @@ def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
     print('Checkpoint {} saved !'.format(epoch+1))
 
 
-parser = OptionParser()
-parser.add_option("-e", "--epochs", dest="epochs", default=5, type="int",
-                  help="number of epochs")
-parser.add_option("-b", "--batch-size", dest="batchsize", default=10,
-                  type="int", help="batch size")
-parser.add_option("-l", "--learning-rate", dest="lr", default=0.1,
-                  type="int", help="learning rate")
-parser.add_option("-g", "--gpu", action="store_true", dest="gpu",
-                  default=False, help="use cuda")
-parser.add_option("-n", "--ngpu", action="store_false", dest="gpu",
-                  default=False, help="use cuda")
+if __name__ == '__main__':
+    parser = OptionParser()
+    parser.add_option('-e', '--epochs', dest='epochs', default=5, type='int',
+                      help='number of epochs')
+    parser.add_option('-b', '--batch-size', dest='batchsize', default=10,
+                      type='int', help='batch size')
+    parser.add_option('-l', '--learning-rate', dest='lr', default=0.1,
+                      type='float', help='learning rate')
+    parser.add_option('-g', '--gpu', action='store_true', dest='gpu',
+                      default=False, help='use cuda')
+    parser.add_option('-c', '--load', dest='load',
+                      default=False, help='load model from a file')
+    (options, args) = parser.parse_args()
 
-(options, args) = parser.parse_args()
+    net = UNet(3, 1)
 
-net = UNet(3, 1)
-if options.gpu:
-    net.cuda()
+    if options.load:
+        net.load_state_dict(torch.load(options.load))
+        print('Model loaded from {}'.format(options.load))
 
-train_net(net, options.epochs, options.batchsize, options.lr, gpu=options.gpu)
+    if options.gpu:
+        net.cuda()
+        cudnn.benchmark = True
+
+    try:
+        train_net(net, options.epochs, options.batchsize, options.lr,
+                  gpu=options.gpu)
+    except KeyboardInterrupt:
+        torch.save(net.state_dict(), 'INTERRUPTED.pth')
+        print('Saved interrupt')
+        try:
+            sys.exit(0)
+        except SystemExit:
+            os._exit(0)
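
train.py now optimizes binary cross-entropy on the flattened sigmoid output instead of DiceLoss. Since nn.BCELoss averages over every element by default, flattening does not change the loss value; a small self-contained check of that claim (toy shapes, old-style Variable API to match the rest of the code):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.autograd import Variable

    criterion = nn.BCELoss()

    # Toy prediction and binary target; the shapes are arbitrary
    y_pred = Variable(torch.randn(2, 1, 4, 4))
    y = Variable(torch.rand(2, 1, 4, 4).round())

    probs = F.sigmoid(y_pred)
    flat = criterion(probs.view(-1), y.view(-1))
    full = criterion(probs, y)
    assert abs(flat.data[0] - full.data[0]) < 1e-6
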
diff --git a/unet_parts.py b/unet_parts.py
index 08fd8d0..7efb750 100644
--- a/unet_parts.py
+++ b/unet_parts.py
@@ -10,9 +10,11 @@ class double_conv(nn.Module):
         super(double_conv, self).__init__()
         self.conv = nn.Sequential(
             nn.Conv2d(in_ch, out_ch, 3, padding=1),
-            nn.ReLU(),
+            nn.BatchNorm2d(out_ch),
+            nn.ReLU(inplace=True),
             nn.Conv2d(out_ch, out_ch, 3, padding=1),
-            nn.ReLU()
+            nn.BatchNorm2d(out_ch),
+            nn.ReLU(inplace=True)
         )
 
     def forward(self, x):
diff --git a/utils.py b/utils.py
index b51cd74..1697178 100644
--- a/utils.py
+++ b/utils.py
@@ -13,7 +13,7 @@ def get_square(img, pos):
     return img[:, -h:]
 
 
-def resize_and_crop(pilimg, scale=0.2, final_height=None):
+def resize_and_crop(pilimg, scale=0.5, final_height=None):
     w = pilimg.size[0]
     h = pilimg.size[1]
     newW = int(w * scale)
@@ -46,6 +46,8 @@ def split_train_val(dataset, val_percent=0.05):
     dataset = list(dataset)
     length = len(dataset)
    n = int(length * val_percent)
+    # Fix the seed so the train/val split is reproducible across runs
+    random.seed(42)
     random.shuffle(dataset)
     return {'train': dataset[:-n], 'val': dataset[-n:]}
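
With kernel size 3, stride 1, and padding=1, each convolution in the revised double_conv preserves spatial size, so inserting BatchNorm2d between the convolution and the ReLU changes only the normalization, never the shapes. A quick standalone shape check mirroring the new block (the channel sizes 3 and 64 are assumed here for illustration):

    import torch
    import torch.nn as nn
    from torch.autograd import Variable

    # Same layer sequence as the updated double_conv, with fixed channels
    block = nn.Sequential(
        nn.Conv2d(3, 64, 3, padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace=True),
        nn.Conv2d(64, 64, 3, padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace=True),
    )

    x = Variable(torch.randn(1, 3, 320, 320))
    print(block(x).size())  # torch.Size([1, 64, 320, 320])
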