mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00
Final tweaks
Former-commit-id: 547d4580c776afa5782a08fc5288526da82a0972
This commit is contained in:
parent
932e1c3cf5
commit
3c14fc3e63
|
@ -1,3 +1,4 @@
|
|||
|
||||
import os
|
||||
from PIL import Image
|
||||
from predict import *
|
||||
|
@ -8,18 +9,20 @@ def submit(net, gpu=False):
|
|||
dir = 'data/test/'
|
||||
|
||||
N = len(list(os.listdir(dir)))
|
||||
with open('SUBMISSION.csv', 'w') as f:
|
||||
with open('SUBMISSION.csv', 'a') as f:
|
||||
|
||||
f.write('img,rle_mask\n')
|
||||
for index, i in enumerate(os.listdir(dir)):
|
||||
print('{}/{}'.format(index, N))
|
||||
|
||||
img = Image.open(dir + i)
|
||||
|
||||
mask = predict_img(net, img, gpu)
|
||||
enc = encode(mask)
|
||||
enc = rle_encode(mask)
|
||||
f.write('{},{}\n'.format(i, ' '.join(map(str, enc))))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = UNet(3, 1).cuda()
|
||||
net.load_state_dict(torch.load('INTERRUPTED.pth'))
|
||||
net.load_state_dict(torch.load('MODEL.pth'))
|
||||
submit(net, True)
|
||||
|
|
44
utils.py
44
utils.py
|
@ -72,20 +72,50 @@ def merge_masks(img1, img2, full_w):
|
|||
return new
|
||||
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def encode(mask):
|
||||
"""mask : HxW"""
|
||||
plt.imshow(mask.transpose())
|
||||
plt.show()
|
||||
flat = mask.transpose().reshape(-1)
|
||||
enc = []
|
||||
i = 0
|
||||
while i < len(flat): # sorry python
|
||||
if(flat[i]):
|
||||
i = 1
|
||||
|
||||
while i <= len(flat):
|
||||
if(flat[i-1]):
|
||||
s = i
|
||||
while(flat[i]):
|
||||
while(flat[i-1]):
|
||||
i += 1
|
||||
e = i-1
|
||||
if(s != e):
|
||||
enc.append(s)
|
||||
enc.append(e - s + 1)
|
||||
enc.append(s)
|
||||
enc.append(e - s + 1)
|
||||
i += 1
|
||||
|
||||
plt.imshow(decode(enc))
|
||||
plt.show()
|
||||
return enc
|
||||
|
||||
def decode(list):
|
||||
mask = np.zeros((1280*1920), np.bool)
|
||||
|
||||
for i, e in enumerate(list):
|
||||
if(i%2 == 0):
|
||||
mask[e-1:e-2+list[i+1]] = True
|
||||
|
||||
mask = mask.reshape(1920, 1280).transpose()
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
def rle_encode(mask_image):
|
||||
pixels = mask_image.flatten()
|
||||
# We avoid issues with '1' at the start or end (at the corners of
|
||||
# the original image) by setting those pixels to '0' explicitly.
|
||||
# We do not expect these to be non-zero for an accurate mask,
|
||||
# so this should not harm the score.
|
||||
pixels[0] = 0
|
||||
pixels[-1] = 0
|
||||
runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
|
||||
runs[1::2] = runs[1::2] - runs[:-1:2]
|
||||
return runs
|
||||
|
|
Loading…
Reference in a new issue