Final tweaks

Former-commit-id: 547d4580c776afa5782a08fc5288526da82a0972
This commit is contained in:
milesial 2017-09-26 21:00:51 +02:00
parent 932e1c3cf5
commit 3c14fc3e63
2 changed files with 43 additions and 10 deletions

View file

@ -1,3 +1,4 @@
import os
from PIL import Image
from predict import *
@ -8,18 +9,20 @@ def submit(net, gpu=False):
dir = 'data/test/'
N = len(list(os.listdir(dir)))
with open('SUBMISSION.csv', 'w') as f:
with open('SUBMISSION.csv', 'a') as f:
f.write('img,rle_mask\n')
for index, i in enumerate(os.listdir(dir)):
print('{}/{}'.format(index, N))
img = Image.open(dir + i)
mask = predict_img(net, img, gpu)
enc = encode(mask)
enc = rle_encode(mask)
f.write('{},{}\n'.format(i, ' '.join(map(str, enc))))
if __name__ == '__main__':
net = UNet(3, 1).cuda()
net.load_state_dict(torch.load('INTERRUPTED.pth'))
net.load_state_dict(torch.load('MODEL.pth'))
submit(net, True)

View file

@ -72,20 +72,50 @@ def merge_masks(img1, img2, full_w):
return new
import matplotlib.pyplot as plt
def encode(mask):
"""mask : HxW"""
plt.imshow(mask.transpose())
plt.show()
flat = mask.transpose().reshape(-1)
enc = []
i = 0
while i < len(flat): # sorry python
if(flat[i]):
i = 1
while i <= len(flat):
if(flat[i-1]):
s = i
while(flat[i]):
while(flat[i-1]):
i += 1
e = i-1
if(s != e):
enc.append(s)
enc.append(e - s + 1)
i += 1
plt.imshow(decode(enc))
plt.show()
return enc
def decode(list):
mask = np.zeros((1280*1920), np.bool)
for i, e in enumerate(list):
if(i%2 == 0):
mask[e-1:e-2+list[i+1]] = True
mask = mask.reshape(1920, 1280).transpose()
return mask
def rle_encode(mask_image):
pixels = mask_image.flatten()
# We avoid issues with '1' at the start or end (at the corners of
# the original image) by setting those pixels to '0' explicitly.
# We do not expect these to be non-zero for an accurate mask,
# so this should not harm the score.
pixels[0] = 0
pixels[-1] = 0
runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
runs[1::2] = runs[1::2] - runs[:-1:2]
return runs