projet-long/submit.py
milesial 8b614c3e31 Modified to take any image size (with even width, height > width/2)
Former-commit-id: 2751e6a3df45c1527376a4697d3804d683095d83
2017-11-30 07:19:52 +01:00

30 lines
746 B
Python

# Predict a mask for every test image and write the run-length-encoded masks to a CSV submission file.
import os
from PIL import Image
from predict import *
from utils import encode
from unet import UNet
def submit(net, gpu=False, test_dir='data/test/', out_path='SUBMISSION.csv'):
    """Predict a mask for every image in *test_dir* and write them to a CSV.

    The output starts with an ``img,rle_mask`` header followed by one row per
    image: ``<filename>,<space-separated values returned by rle_encode>``.

    Args:
        net: the trained network, passed straight to ``predict_img``.
        gpu: forwarded to ``predict_img`` (whether to run inference on GPU).
        test_dir: directory whose entries are the test images.
        out_path: path of the CSV to create; an existing file is overwritten.
    """
    # List the directory once so the progress total and the loop agree, and
    # sort it so the row order is reproducible (os.listdir order is arbitrary).
    names = sorted(os.listdir(test_dir))
    total = len(names)
    # 'w', not 'a': appending would add a second header and duplicate rows to
    # an existing submission, producing an invalid CSV.
    with open(out_path, 'w') as f:
        f.write('img,rle_mask\n')
        for index, name in enumerate(names, 1):
            print('{}/{}'.format(index, total))
            # The context manager closes the image's file handle deterministically.
            with Image.open(os.path.join(test_dir, name)) as img:
                mask = predict_img(net, img, gpu)
            # NOTE(review): rle_encode is not imported explicitly (only
            # utils.encode is); presumably it arrives via `from predict import *`.
            enc = rle_encode(mask)
            f.write('{},{}\n'.format(name, ' '.join(map(str, enc))))
if __name__ == '__main__':
    # Use the GPU when one is available, but fall back to the CPU instead of
    # crashing on machines without CUDA.
    use_gpu = torch.cuda.is_available()
    net = UNet(3, 1)
    if use_gpu:
        net = net.cuda()
        state = torch.load('MODEL.pth')
    else:
        # Remap any CUDA-saved tensors onto the CPU; the callable form of
        # map_location works on both old and new torch versions.
        state = torch.load('MODEL.pth', map_location=lambda storage, loc: storage)
    net.load_state_dict(state)
    # Put BatchNorm/Dropout layers into inference mode; harmless if
    # predict_img also does this, since eval() is idempotent.
    net.eval()
    submit(net, use_gpu)