From 74ee006a7a4dbde4c9aef16680b1424632c7cfe4 Mon Sep 17 00:00:00 2001 From: milesial Date: Mon, 21 Aug 2017 18:00:07 +0200 Subject: [PATCH] Added simple predict + submit script --- crf.py | 7 +++---- eval.py | 6 +++--- load.py | 5 +++++ predict.py | 41 +++++++++++++++++++++++++++++++++++++++++ submit.py | 25 +++++++++++++++++++++++++ utils.py | 36 ++++++++++++++++++++++++++++++++++++ 6 files changed, 113 insertions(+), 7 deletions(-) create mode 100644 predict.py create mode 100644 submit.py diff --git a/crf.py b/crf.py index 06cb61e..713a47b 100644 --- a/crf.py +++ b/crf.py @@ -8,7 +8,6 @@ def dense_crf(img, output_probs): output_probs = np.expand_dims(output_probs, 0) output_probs = np.append(1 - output_probs, output_probs, axis=0) - print(output_probs.shape) d = dcrf.DenseCRF2D(w, h, 2) U = -np.log(output_probs) @@ -19,10 +18,10 @@ def dense_crf(img, output_probs): d.setUnaryEnergy(U) - d.addPairwiseGaussian(sxy=10, compat=3) - d.addPairwiseBilateral(sxy=50, srgb=20, rgbim=img, compat=10) + d.addPairwiseGaussian(sxy=20, compat=3) + d.addPairwiseBilateral(sxy=30, srgb=20, rgbim=img, compat=10) - Q = d.inference(30) + Q = d.inference(5) Q = np.argmax(np.array(Q), axis=0).reshape((h, w)) return Q diff --git a/eval.py b/eval.py index cf2ea37..ee6d666 100644 --- a/eval.py +++ b/eval.py @@ -32,7 +32,7 @@ def eval_net(net, dataset, gpu=False): dice = dice_coeff(y_pred, y.float()).data[0] tot += dice - if 0: + if 1: X = X.data.squeeze(0).cpu().numpy() X = np.transpose(X, axes=[1, 2, 0]) y = y.data.squeeze(0).cpu().numpy() @@ -45,12 +45,12 @@ def eval_net(net, dataset, gpu=False): ax2 = fig.add_subplot(1, 4, 2) ax2.imshow(y) ax3 = fig.add_subplot(1, 4, 3) - ax3.imshow((y_pred > 0.6)) + ax3.imshow((y_pred > 0.5)) Q = dense_crf(((X*255).round()).astype(np.uint8), y_pred) ax4 = fig.add_subplot(1, 4, 4) print(Q) - ax4.imshow(Q) + ax4.imshow(Q > 0.5) plt.show() return tot / i diff --git a/load.py b/load.py index c4418c1..9b847d0 100644 --- a/load.py +++ b/load.py @@ -40,3 +40,8 @@ def get_imgs_and_masks(ids, dir_img, dir_mask): masks = to_cropped_imgs(ids, dir_mask, '_mask.gif') return zip(imgs_normalized, masks) + +def get_full_img_and_mask(id, dir_img, dir_mask): + im = Image.open(dir_img + id + '.jpg') + mask = Image.open(dir_mask + id + '_mask.gif') + return np.array(im), np.array(mask) diff --git a/predict.py b/predict.py new file mode 100644 index 0000000..26fdeff --- /dev/null +++ b/predict.py @@ -0,0 +1,41 @@ +import torch +from utils import * +import torch.nn.functional as F +from PIL import Image +from unet_model import UNet +from torch.autograd import Variable +import matplotlib.pyplot as plt +from crf import dense_crf + + +def predict_img(net, full_img, gpu=False): + img = resize_and_crop(full_img) + + left = get_square(img, 0) + right = get_square(img, 1) + + right = normalize(right) + left = normalize(left) + + right = np.transpose(right, axes=[2, 0, 1]) + left = np.transpose(left, axes=[2, 0, 1]) + + X_l = torch.FloatTensor(left).unsqueeze(0) + X_r = torch.FloatTensor(right).unsqueeze(0) + + if gpu: + X_l = Variable(X_l, volatile=True).cuda() + X_r = Variable(X_r, volatile=True).cuda() + else: + X_l = Variable(X_l, volatile=True) + X_r = Variable(X_r, volatile=True) + + y_l = F.sigmoid(net(X_l)) + y_r = F.sigmoid(net(X_r)) + y_l = F.upsample_bilinear(y_l, scale_factor=2).data[0][0].cpu().numpy() + y_r = F.upsample_bilinear(y_r, scale_factor=2).data[0][0].cpu().numpy() + + y = merge_masks(y_l, y_r, 1918) + yy = dense_crf(np.array(full_img).astype(np.uint8), y) + + return yy > 0.5 diff --git a/submit.py b/submit.py new file mode 100644 index 0000000..14b0396 --- /dev/null +++ b/submit.py @@ -0,0 +1,25 @@ +import os +from PIL import Image +from predict import * +from utils import encode +from unet_model import UNet + +def submit(net, gpu=False): + dir = 'data/test/' + + N = len(list(os.listdir(dir))) + with open('SUBMISSION.csv', 'w') as f: + f.write('img,rle_mask\n') + for index, i in enumerate(os.listdir(dir)): + print('{}/{}'.format(index, N)) + img = Image.open(dir + i) + + mask = predict_img(net, img, gpu) + enc = encode(mask) + f.write('{},{}\n'.format(i, ' '.join(map(str, enc)))) + + +if __name__ == '__main__': + net = UNet(3, 1).cuda() + net.load_state_dict(torch.load('INTERRUPTED.pth')) + submit(net, True) diff --git a/utils.py b/utils.py index 1697178..807b522 100644 --- a/utils.py +++ b/utils.py @@ -53,3 +53,39 @@ def split_train_val(dataset, val_percent=0.05): def normalize(x): return x / 255 + + +def merge_masks(img1, img2, full_w): + w = img1.shape[1] + overlap = int(2 * w - full_w) + h = img1.shape[0] + + new = np.zeros((h, full_w), np.float32) + + margin = 0 + + new[:, :full_w//2+1] = img1[:, :full_w//2+1] + new[:, full_w//2+1:] = img2[:, -(full_w//2-1):] + #new[:, w-overlap+1+margin//2:-(w-overlap+margin//2)] = (img1[:, -overlap+margin:] + + # img2[:, :overlap-margin])/2 + + return new + + +def encode(mask): + """mask : HxW""" + flat = mask.transpose().reshape(-1) + enc = [] + i = 0 + while i < len(flat): # sorry python + if(flat[i]): + s = i + while(flat[i]): + i += 1 + e = i-1 + if(s != e): + enc.append(s) + enc.append(e - s + 1) + i += 1 + + return enc