Fix device in submit.py

Former-commit-id: 2eb8a6ff539ac0f4838ce1161b6d239e912ca007
This commit is contained in:
milesial 2020-07-14 11:37:05 -07:00 committed by GitHub
parent 63f889a4b0
commit cc2ac3db07

View file

@ -23,10 +23,10 @@ def rle_encode(mask_image):
return runs return runs
def submit(net, gpu=False): def submit(net):
"""Used for Kaggle submission: predicts and encode all test images""" """Used for Kaggle submission: predicts and encode all test images"""
dir = 'data/test/' dir = 'data/test/'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
N = len(list(os.listdir(dir))) N = len(list(os.listdir(dir)))
with open('SUBMISSION.csv', 'a') as f: with open('SUBMISSION.csv', 'a') as f:
f.write('img,rle_mask\n') f.write('img,rle_mask\n')
@ -35,7 +35,7 @@ def submit(net, gpu=False):
img = Image.open(dir + i) img = Image.open(dir + i)
mask = predict_img(net, img, gpu) mask = predict_img(net, img, device)
enc = rle_encode(mask) enc = rle_encode(mask)
f.write('{},{}\n'.format(i, ' '.join(map(str, enc)))) f.write('{},{}\n'.format(i, ' '.join(map(str, enc))))
@ -43,4 +43,4 @@ def submit(net, gpu=False):
if __name__ == '__main__': if __name__ == '__main__':
net = UNet(3, 1).cuda() net = UNet(3, 1).cuda()
net.load_state_dict(torch.load('MODEL.pth')) net.load_state_dict(torch.load('MODEL.pth'))
submit(net, True) submit(net)