From cc2ac3db07a2ca2b935d195be7ff92cf1996f33a Mon Sep 17 00:00:00 2001 From: milesial Date: Tue, 14 Jul 2020 11:37:05 -0700 Subject: [PATCH] Fix device in submit.py Former-commit-id: 2eb8a6ff539ac0f4838ce1161b6d239e912ca007 --- submit.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/submit.py b/submit.py index 65a0df5..a5609cd 100644 --- a/submit.py +++ b/submit.py @@ -23,10 +23,10 @@ def rle_encode(mask_image): return runs -def submit(net, gpu=False): +def submit(net): """Used for Kaggle submission: predicts and encode all test images""" dir = 'data/test/' - + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') N = len(list(os.listdir(dir))) with open('SUBMISSION.csv', 'a') as f: f.write('img,rle_mask\n') @@ -35,7 +35,7 @@ def submit(net, gpu=False): img = Image.open(dir + i) - mask = predict_img(net, img, gpu) + mask = predict_img(net, img, device) enc = rle_encode(mask) f.write('{},{}\n'.format(i, ' '.join(map(str, enc)))) @@ -43,4 +43,4 @@ def submit(net, gpu=False): if __name__ == '__main__': net = UNet(3, 1).cuda() net.load_state_dict(torch.load('MODEL.pth')) - submit(net, True) + submit(net)