2019-10-24 19:37:21 +00:00
|
|
|
""" Submit code specific to the kaggle challenge"""
|
|
|
|
|
2018-06-08 17:27:32 +00:00
|
|
|
import os
|
2017-09-26 19:00:51 +00:00
|
|
|
|
2018-06-08 17:27:32 +00:00
|
|
|
import torch
|
2019-10-24 19:37:21 +00:00
|
|
|
from PIL import Image
|
2019-11-23 16:56:14 +00:00
|
|
|
import numpy as np
|
2018-06-08 17:27:32 +00:00
|
|
|
|
|
|
|
from predict import predict_img
|
2017-11-30 05:45:19 +00:00
|
|
|
from unet import UNet
|
2019-11-23 16:56:14 +00:00
|
|
|
|
|
|
|
# credits to https://stackoverflow.com/users/6076729/manuel-lagunas
|
|
|
|
def rle_encode(mask_image):
    """Return the run-length encoding of ``mask_image`` as a flat integer array.

    The mask is flattened with ``flatten()`` (NumPy's default row-major order).
    Each run is then described by two integers: its 1-indexed start position
    and its length, laid out as ``[start_1, len_1, start_2, len_2, ...]``.

    The first and last flattened pixels are forced to 0 before encoding. Every
    change of value counts as a run boundary, so the output only reads as
    start/length pairs for a binary (0/1 or bool) mask.
    NOTE(review): presumably ``predict_img`` returns such a binary mask; confirm.
    """
    # For a NumPy array, flatten() returns a copy, so the writes below do not
    # touch the caller's mask.
    flat = mask_image.flatten()

    # Zero the two corner pixels. An accurate mask should not have them set,
    # and with both ends at 0 each foreground run has a start and an end.
    flat[0] = 0
    flat[-1] = 0

    # 1-indexed position of the first pixel after every value change:
    # +1 because the change is detected on the pair (i, i + 1), +1 for 1-indexing.
    boundaries = np.flatnonzero(flat[1:] != flat[:-1]) + 2

    # Even slots keep the run start; odd slots become (run end - run start).
    encoded = boundaries.copy()
    encoded[1::2] = boundaries[1::2] - boundaries[:-1:2]
    return encoded
|
2017-08-21 16:00:07 +00:00
|
|
|
|
2018-04-09 03:15:24 +00:00
|
|
|
|
2020-07-14 18:37:05 +00:00
|
|
|
def submit(net, *, test_dir='data/test/', output_path='SUBMISSION.csv'):
    """Used for Kaggle submission: predict and RLE-encode every test image.

    Runs ``predict_img`` on each file in ``test_dir``, encodes the resulting
    mask with ``rle_encode`` and writes one ``img,rle_mask`` row per image.

    Args:
        net: the trained network, passed through unchanged to ``predict_img``.
        test_dir: directory containing the test images (default ``'data/test/'``).
        output_path: CSV file to write. It is truncated and rewritten on every
            call (default ``'SUBMISSION.csv'``).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # List the directory once so the progress total and the loop agree, and
    # sort it so repeated runs produce the same row order.
    names = sorted(os.listdir(test_dir))
    total = len(names)

    # 'w' rather than 'a': appending would add another header line (and
    # duplicate rows) on every re-run, which makes the submission invalid.
    with open(output_path, 'w') as f:
        f.write('img,rle_mask\n')
        for index, name in enumerate(names, start=1):
            print('{}/{}'.format(index, total))

            # Close each image file promptly instead of leaving one open
            # handle per test image until garbage collection.
            with Image.open(os.path.join(test_dir, name)) as img:
                mask = predict_img(net, img, device)
                enc = rle_encode(mask)

            f.write('{},{}\n'.format(name, ' '.join(map(str, enc))))
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Use the GPU when one is available, otherwise fall back to the CPU. This
    # mirrors the device choice inside submit(), so the script also runs on
    # CUDA-less machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = UNet(3, 1).to(device)
    # map_location lets a checkpoint that was saved from a GPU be loaded on a
    # CPU-only host.
    net.load_state_dict(torch.load('MODEL.pth', map_location=device))
    submit(net)
|