# Mirror of https://github.com/Laurent2916/REVA-QCAV.git
# (synced 2024-11-08 14:39:00 +00:00) — 26 lines, 684 B, Python
|
import os
|
||
|
from PIL import Image
|
||
|
from predict import *
|
||
|
from utils import encode
|
||
|
from unet_model import UNet
|
||
|
|
||
|
def submit(net, gpu=False, data_dir='data/test/', out_path='SUBMISSION.csv'):
    """Predict a mask for every image in ``data_dir`` and write an RLE CSV.

    Each file in ``data_dir`` is opened with PIL, passed to ``predict_img``,
    and the resulting mask is run-length encoded with ``encode``. One
    ``img,rle_mask`` row per file is written to ``out_path``, which is
    overwritten if it already exists.

    Args:
        net: model forwarded to ``predict_img``.
        gpu: forwarded to ``predict_img`` (presumably selects CUDA
            inference — confirm against ``predict_img``).
        data_dir: directory holding the test images. The default is the
            path the script originally hard-coded.
        out_path: destination CSV file. The default is the original name.
    """
    # List the directory once so the progress total and the files we
    # actually iterate always come from the same snapshot.
    names = os.listdir(data_dir)
    total = len(names)

    with open(out_path, 'w') as f:
        f.write('img,rle_mask\n')
        # start=1 so progress runs 1/N .. N/N instead of ending at N-1/N.
        for index, name in enumerate(names, start=1):
            print('{}/{}'.format(index, total))
            # The context manager guarantees the image's file handle is
            # released, even if predict_img raises or never loads pixels.
            with Image.open(os.path.join(data_dir, name)) as img:
                mask = predict_img(net, img, gpu)
            enc = encode(mask)
            f.write('{},{}\n'.format(name, ' '.join(map(str, enc))))
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Import explicitly rather than relying on `from predict import *`
    # happening to re-export torch.
    import torch

    # UNet(3, 1): arguments as in the original script (presumably 3 input
    # channels and 1 output mask channel — confirm against unet_model.UNet).
    net = UNet(3, 1).cuda()
    net.load_state_dict(torch.load('INTERRUPTED.pth'))
    # Inference: disable training-mode behaviour (dropout, batch-norm
    # statistics updates). Harmless if predict_img also calls eval().
    net.eval()

    submit(net, gpu=True)
|