From 3c14fc3e63271ddb14ddbfd76786602e2352c0fe Mon Sep 17 00:00:00 2001 From: milesial Date: Tue, 26 Sep 2017 21:00:51 +0200 Subject: [PATCH] Final tweaks Former-commit-id: 547d4580c776afa5782a08fc5288526da82a0972 --- submit.py | 9 ++++++--- utils.py | 44 +++++++++++++++++++++++++++++++++++++------- 2 files changed, 43 insertions(+), 10 deletions(-) diff --git a/submit.py b/submit.py index 14b0396..91aeefe 100644 --- a/submit.py +++ b/submit.py @@ -1,3 +1,4 @@ + import os from PIL import Image from predict import * @@ -8,18 +9,20 @@ def submit(net, gpu=False): dir = 'data/test/' N = len(list(os.listdir(dir))) - with open('SUBMISSION.csv', 'w') as f: + with open('SUBMISSION.csv', 'a') as f: + f.write('img,rle_mask\n') for index, i in enumerate(os.listdir(dir)): print('{}/{}'.format(index, N)) + img = Image.open(dir + i) mask = predict_img(net, img, gpu) - enc = encode(mask) + enc = rle_encode(mask) f.write('{},{}\n'.format(i, ' '.join(map(str, enc)))) if __name__ == '__main__': net = UNet(3, 1).cuda() - net.load_state_dict(torch.load('INTERRUPTED.pth')) + net.load_state_dict(torch.load('MODEL.pth')) submit(net, True) diff --git a/utils.py b/utils.py index 807b522..a3b2461 100644 --- a/utils.py +++ b/utils.py @@ -72,20 +72,50 @@ def merge_masks(img1, img2, full_w): return new +import matplotlib.pyplot as plt + def encode(mask): """mask : HxW""" + plt.imshow(mask.transpose()) + plt.show() flat = mask.transpose().reshape(-1) enc = [] - i = 0 - while i < len(flat): # sorry python - if(flat[i]): + i = 1 + + while i <= len(flat): + if(flat[i-1]): s = i - while(flat[i]): + while(flat[i-1]): i += 1 e = i-1 - if(s != e): - enc.append(s) - enc.append(e - s + 1) + enc.append(s) + enc.append(e - s + 1) i += 1 + plt.imshow(decode(enc)) + plt.show() return enc + +def decode(list): + mask = np.zeros((1280*1920), np.bool) + + for i, e in enumerate(list): + if(i%2 == 0): + mask[e-1:e-2+list[i+1]] = True + + mask = mask.reshape(1920, 1280).transpose() + + return mask + + +def rle_encode(mask_image): + pixels = mask_image.flatten() + # We avoid issues with '1' at the start or end (at the corners of + # the original image) by setting those pixels to '0' explicitly. + # We do not expect these to be non-zero for an accurate mask, + # so this should not harm the score. + pixels[0] = 0 + pixels[-1] = 0 + runs = np.where(pixels[1:] != pixels[:-1])[0] + 2 + runs[1::2] = runs[1::2] - runs[:-1:2] + return runs