Fix device in submit.py
Former-commit-id: 2eb8a6ff539ac0f4838ce1161b6d239e912ca007
This commit is contained in:
parent
63f889a4b0
commit
cc2ac3db07
|
@ -23,10 +23,10 @@ def rle_encode(mask_image):
|
||||||
return runs
|
return runs
|
||||||
|
|
||||||
|
|
||||||
def submit(net, gpu=False):
|
def submit(net):
|
||||||
"""Used for Kaggle submission: predicts and encode all test images"""
|
"""Used for Kaggle submission: predicts and encode all test images"""
|
||||||
dir = 'data/test/'
|
dir = 'data/test/'
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
N = len(list(os.listdir(dir)))
|
N = len(list(os.listdir(dir)))
|
||||||
with open('SUBMISSION.csv', 'a') as f:
|
with open('SUBMISSION.csv', 'a') as f:
|
||||||
f.write('img,rle_mask\n')
|
f.write('img,rle_mask\n')
|
||||||
|
@ -35,7 +35,7 @@ def submit(net, gpu=False):
|
||||||
|
|
||||||
img = Image.open(dir + i)
|
img = Image.open(dir + i)
|
||||||
|
|
||||||
mask = predict_img(net, img, gpu)
|
mask = predict_img(net, img, device)
|
||||||
enc = rle_encode(mask)
|
enc = rle_encode(mask)
|
||||||
f.write('{},{}\n'.format(i, ' '.join(map(str, enc))))
|
f.write('{},{}\n'.format(i, ' '.join(map(str, enc))))
|
||||||
|
|
||||||
|
@ -43,4 +43,4 @@ def submit(net, gpu=False):
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
net = UNet(3, 1).cuda()
|
net = UNet(3, 1).cuda()
|
||||||
net.load_state_dict(torch.load('MODEL.pth'))
|
net.load_state_dict(torch.load('MODEL.pth'))
|
||||||
submit(net, True)
|
submit(net)
|
||||||
|
|
Loading…
Reference in a new issue