diff --git a/submit.py b/submit.py index 65a0df5..a5609cd 100644 --- a/submit.py +++ b/submit.py @@ -23,10 +23,10 @@ def rle_encode(mask_image): return runs -def submit(net, gpu=False): +def submit(net): """Used for Kaggle submission: predicts and encode all test images""" dir = 'data/test/' - + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') N = len(list(os.listdir(dir))) with open('SUBMISSION.csv', 'a') as f: f.write('img,rle_mask\n') @@ -35,7 +35,7 @@ def submit(net, gpu=False): img = Image.open(dir + i) - mask = predict_img(net, img, gpu) + mask = predict_img(net, img, device) enc = rle_encode(mask) f.write('{},{}\n'.format(i, ' '.join(map(str, enc)))) @@ -43,4 +43,4 @@ def submit(net, gpu=False): if __name__ == '__main__': net = UNet(3, 1).cuda() net.load_state_dict(torch.load('MODEL.pth')) - submit(net, True) + submit(net)