mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
Final tweaks
Former-commit-id: 547d4580c776afa5782a08fc5288526da82a0972
This commit is contained in:
parent
932e1c3cf5
commit
3c14fc3e63
|
@ -1,3 +1,4 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from predict import *
|
from predict import *
|
||||||
|
@ -8,18 +9,20 @@ def submit(net, gpu=False):
|
||||||
dir = 'data/test/'
|
dir = 'data/test/'
|
||||||
|
|
||||||
N = len(list(os.listdir(dir)))
|
N = len(list(os.listdir(dir)))
|
||||||
with open('SUBMISSION.csv', 'w') as f:
|
with open('SUBMISSION.csv', 'a') as f:
|
||||||
|
|
||||||
f.write('img,rle_mask\n')
|
f.write('img,rle_mask\n')
|
||||||
for index, i in enumerate(os.listdir(dir)):
|
for index, i in enumerate(os.listdir(dir)):
|
||||||
print('{}/{}'.format(index, N))
|
print('{}/{}'.format(index, N))
|
||||||
|
|
||||||
img = Image.open(dir + i)
|
img = Image.open(dir + i)
|
||||||
|
|
||||||
mask = predict_img(net, img, gpu)
|
mask = predict_img(net, img, gpu)
|
||||||
enc = encode(mask)
|
enc = rle_encode(mask)
|
||||||
f.write('{},{}\n'.format(i, ' '.join(map(str, enc))))
|
f.write('{},{}\n'.format(i, ' '.join(map(str, enc))))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
net = UNet(3, 1).cuda()
|
net = UNet(3, 1).cuda()
|
||||||
net.load_state_dict(torch.load('INTERRUPTED.pth'))
|
net.load_state_dict(torch.load('MODEL.pth'))
|
||||||
submit(net, True)
|
submit(net, True)
|
||||||
|
|
40
utils.py
40
utils.py
|
@ -72,20 +72,50 @@ def merge_masks(img1, img2, full_w):
|
||||||
return new
|
return new
|
||||||
|
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
def encode(mask):
|
def encode(mask):
|
||||||
"""mask : HxW"""
|
"""mask : HxW"""
|
||||||
|
plt.imshow(mask.transpose())
|
||||||
|
plt.show()
|
||||||
flat = mask.transpose().reshape(-1)
|
flat = mask.transpose().reshape(-1)
|
||||||
enc = []
|
enc = []
|
||||||
i = 0
|
i = 1
|
||||||
while i < len(flat): # sorry python
|
|
||||||
if(flat[i]):
|
while i <= len(flat):
|
||||||
|
if(flat[i-1]):
|
||||||
s = i
|
s = i
|
||||||
while(flat[i]):
|
while(flat[i-1]):
|
||||||
i += 1
|
i += 1
|
||||||
e = i-1
|
e = i-1
|
||||||
if(s != e):
|
|
||||||
enc.append(s)
|
enc.append(s)
|
||||||
enc.append(e - s + 1)
|
enc.append(e - s + 1)
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
|
plt.imshow(decode(enc))
|
||||||
|
plt.show()
|
||||||
return enc
|
return enc
|
||||||
|
|
||||||
|
def decode(list):
|
||||||
|
mask = np.zeros((1280*1920), np.bool)
|
||||||
|
|
||||||
|
for i, e in enumerate(list):
|
||||||
|
if(i%2 == 0):
|
||||||
|
mask[e-1:e-2+list[i+1]] = True
|
||||||
|
|
||||||
|
mask = mask.reshape(1920, 1280).transpose()
|
||||||
|
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def rle_encode(mask_image):
|
||||||
|
pixels = mask_image.flatten()
|
||||||
|
# We avoid issues with '1' at the start or end (at the corners of
|
||||||
|
# the original image) by setting those pixels to '0' explicitly.
|
||||||
|
# We do not expect these to be non-zero for an accurate mask,
|
||||||
|
# so this should not harm the score.
|
||||||
|
pixels[0] = 0
|
||||||
|
pixels[-1] = 0
|
||||||
|
runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
|
||||||
|
runs[1::2] = runs[1::2] - runs[:-1:2]
|
||||||
|
return runs
|
||||||
|
|
Loading…
Reference in a new issue