REVA-QCAV/submit.py

34 lines
834 B
Python
Raw Normal View History

""" Submit code specific to the kaggle challenge"""
import os
import torch
from PIL import Image
from predict import predict_img
from unet import UNet
from utils import rle_encode
2017-08-21 16:00:07 +00:00
2017-08-21 16:00:07 +00:00
def submit(net, gpu=False, test_dir='data/test/', out_path='SUBMISSION.csv'):
    """Predict a mask for every test image and write a Kaggle submission CSV.

    Each file in ``test_dir`` is opened with PIL and passed to ``predict_img``.
    The resulting mask is run-length encoded with ``rle_encode``. The CSV starts
    with an ``img,rle_mask`` header, followed by one row per image of the form
    ``<filename>,<space-separated RLE values>``.

    Args:
        net: Trained network forwarded to ``predict_img``.
        gpu: Forwarded to ``predict_img`` (presumably selects CUDA inference —
            confirm against ``predict.predict_img``).
        test_dir: Directory containing the test images. Every entry in it is
            treated as an image.
        out_path: Destination CSV path. The file is overwritten on every call.
    """
    # List the directory once, so the progress total and the loop agree.
    # Sort it, because os.listdir order is arbitrary and we want reproducible
    # row order across runs and filesystems.
    names = sorted(os.listdir(test_dir))
    total = len(names)

    # 'w' rather than 'a': appending would repeat the header (and every row)
    # on each re-run, producing an invalid submission file.
    with open(out_path, 'w') as f:
        f.write('img,rle_mask\n')
        for index, name in enumerate(names, start=1):
            print('{}/{}'.format(index, total))

            # The context manager releases the file handle PIL keeps open for
            # lazily-loaded images. Prediction happens while it is still open.
            with Image.open(os.path.join(test_dir, name)) as img:
                mask = predict_img(net, img, gpu)
            enc = rle_encode(mask)

            f.write('{},{}\n'.format(name, ' '.join(map(str, enc))))
if __name__ == '__main__':
    # UNet(3, 1): presumably 3 input channels (RGB) and 1 output mask channel —
    # confirm against unet.UNet.
    net = UNet(3, 1)
    # Load onto the CPU first, so a checkpoint saved from any device
    # (e.g. a different GPU index) can be restored. Then move the model to CUDA.
    net.load_state_dict(torch.load('MODEL.pth', map_location='cpu'))
    net.cuda()
    # Inference only: disable training-mode behaviour such as dropout and
    # batch-norm statistic updates. This is harmless if predict_img also does it.
    net.eval()

    submit(net, True)