Added simple predict + submit script

This commit is contained in:
milesial 2017-08-21 18:00:07 +02:00
parent fa40396fff
commit 74ee006a7a
6 changed files with 113 additions and 7 deletions

7
crf.py
View file

@ -8,7 +8,6 @@ def dense_crf(img, output_probs):
output_probs = np.expand_dims(output_probs, 0)
output_probs = np.append(1 - output_probs, output_probs, axis=0)
print(output_probs.shape)
d = dcrf.DenseCRF2D(w, h, 2)
U = -np.log(output_probs)
@ -19,10 +18,10 @@ def dense_crf(img, output_probs):
d.setUnaryEnergy(U)
d.addPairwiseGaussian(sxy=10, compat=3)
d.addPairwiseBilateral(sxy=50, srgb=20, rgbim=img, compat=10)
d.addPairwiseGaussian(sxy=20, compat=3)
d.addPairwiseBilateral(sxy=30, srgb=20, rgbim=img, compat=10)
Q = d.inference(30)
Q = d.inference(5)
Q = np.argmax(np.array(Q), axis=0).reshape((h, w))
return Q

View file

@ -32,7 +32,7 @@ def eval_net(net, dataset, gpu=False):
dice = dice_coeff(y_pred, y.float()).data[0]
tot += dice
if 0:
if 1:
X = X.data.squeeze(0).cpu().numpy()
X = np.transpose(X, axes=[1, 2, 0])
y = y.data.squeeze(0).cpu().numpy()
@ -45,12 +45,12 @@ def eval_net(net, dataset, gpu=False):
ax2 = fig.add_subplot(1, 4, 2)
ax2.imshow(y)
ax3 = fig.add_subplot(1, 4, 3)
ax3.imshow((y_pred > 0.6))
ax3.imshow((y_pred > 0.5))
Q = dense_crf(((X*255).round()).astype(np.uint8), y_pred)
ax4 = fig.add_subplot(1, 4, 4)
print(Q)
ax4.imshow(Q)
ax4.imshow(Q > 0.5)
plt.show()
return tot / i

View file

@ -40,3 +40,8 @@ def get_imgs_and_masks(ids, dir_img, dir_mask):
masks = to_cropped_imgs(ids, dir_mask, '_mask.gif')
return zip(imgs_normalized, masks)
def get_full_img_and_mask(id, dir_img, dir_mask):
im = Image.open(dir_img + id + '.jpg')
mask = Image.open(dir_mask + id + '_mask.gif')
return np.array(im), np.array(mask)

41
predict.py Normal file
View file

@ -0,0 +1,41 @@
import torch
from utils import *
import torch.nn.functional as F
from PIL import Image
from unet_model import UNet
from torch.autograd import Variable
import matplotlib.pyplot as plt
from crf import dense_crf
def predict_img(net, full_img, gpu=False):
img = resize_and_crop(full_img)
left = get_square(img, 0)
right = get_square(img, 1)
right = normalize(right)
left = normalize(left)
right = np.transpose(right, axes=[2, 0, 1])
left = np.transpose(left, axes=[2, 0, 1])
X_l = torch.FloatTensor(left).unsqueeze(0)
X_r = torch.FloatTensor(right).unsqueeze(0)
if gpu:
X_l = Variable(X_l, volatile=True).cuda()
X_r = Variable(X_r, volatile=True).cuda()
else:
X_l = Variable(X_l, volatile=True)
X_r = Variable(X_r, volatile=True)
y_l = F.sigmoid(net(X_l))
y_r = F.sigmoid(net(X_r))
y_l = F.upsample_bilinear(y_l, scale_factor=2).data[0][0].cpu().numpy()
y_r = F.upsample_bilinear(y_r, scale_factor=2).data[0][0].cpu().numpy()
y = merge_masks(y_l, y_r, 1918)
yy = dense_crf(np.array(full_img).astype(np.uint8), y)
return yy > 0.5

25
submit.py Normal file
View file

@ -0,0 +1,25 @@
import os
from PIL import Image
from predict import *
from utils import encode
from unet_model import UNet
def submit(net, gpu=False):
dir = 'data/test/'
N = len(list(os.listdir(dir)))
with open('SUBMISSION.csv', 'w') as f:
f.write('img,rle_mask\n')
for index, i in enumerate(os.listdir(dir)):
print('{}/{}'.format(index, N))
img = Image.open(dir + i)
mask = predict_img(net, img, gpu)
enc = encode(mask)
f.write('{},{}\n'.format(i, ' '.join(map(str, enc))))
if __name__ == '__main__':
net = UNet(3, 1).cuda()
net.load_state_dict(torch.load('INTERRUPTED.pth'))
submit(net, True)

View file

@ -53,3 +53,39 @@ def split_train_val(dataset, val_percent=0.05):
def normalize(x):
return x / 255
def merge_masks(img1, img2, full_w):
w = img1.shape[1]
overlap = int(2 * w - full_w)
h = img1.shape[0]
new = np.zeros((h, full_w), np.float32)
margin = 0
new[:, :full_w//2+1] = img1[:, :full_w//2+1]
new[:, full_w//2+1:] = img2[:, -(full_w//2-1):]
#new[:, w-overlap+1+margin//2:-(w-overlap+margin//2)] = (img1[:, -overlap+margin:] +
# img2[:, :overlap-margin])/2
return new
def encode(mask):
"""mask : HxW"""
flat = mask.transpose().reshape(-1)
enc = []
i = 0
while i < len(flat): # sorry python
if(flat[i]):
s = i
while(flat[i]):
i += 1
e = i-1
if(s != e):
enc.append(s)
enc.append(e - s + 1)
i += 1
return enc