mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00
Added simple predict + submit script
This commit is contained in:
parent
fa40396fff
commit
74ee006a7a
7
crf.py
7
crf.py
|
@ -8,7 +8,6 @@ def dense_crf(img, output_probs):
|
||||||
|
|
||||||
output_probs = np.expand_dims(output_probs, 0)
|
output_probs = np.expand_dims(output_probs, 0)
|
||||||
output_probs = np.append(1 - output_probs, output_probs, axis=0)
|
output_probs = np.append(1 - output_probs, output_probs, axis=0)
|
||||||
print(output_probs.shape)
|
|
||||||
|
|
||||||
d = dcrf.DenseCRF2D(w, h, 2)
|
d = dcrf.DenseCRF2D(w, h, 2)
|
||||||
U = -np.log(output_probs)
|
U = -np.log(output_probs)
|
||||||
|
@ -19,10 +18,10 @@ def dense_crf(img, output_probs):
|
||||||
|
|
||||||
d.setUnaryEnergy(U)
|
d.setUnaryEnergy(U)
|
||||||
|
|
||||||
d.addPairwiseGaussian(sxy=10, compat=3)
|
d.addPairwiseGaussian(sxy=20, compat=3)
|
||||||
d.addPairwiseBilateral(sxy=50, srgb=20, rgbim=img, compat=10)
|
d.addPairwiseBilateral(sxy=30, srgb=20, rgbim=img, compat=10)
|
||||||
|
|
||||||
Q = d.inference(30)
|
Q = d.inference(5)
|
||||||
Q = np.argmax(np.array(Q), axis=0).reshape((h, w))
|
Q = np.argmax(np.array(Q), axis=0).reshape((h, w))
|
||||||
|
|
||||||
return Q
|
return Q
|
||||||
|
|
6
eval.py
6
eval.py
|
@ -32,7 +32,7 @@ def eval_net(net, dataset, gpu=False):
|
||||||
dice = dice_coeff(y_pred, y.float()).data[0]
|
dice = dice_coeff(y_pred, y.float()).data[0]
|
||||||
tot += dice
|
tot += dice
|
||||||
|
|
||||||
if 0:
|
if 1:
|
||||||
X = X.data.squeeze(0).cpu().numpy()
|
X = X.data.squeeze(0).cpu().numpy()
|
||||||
X = np.transpose(X, axes=[1, 2, 0])
|
X = np.transpose(X, axes=[1, 2, 0])
|
||||||
y = y.data.squeeze(0).cpu().numpy()
|
y = y.data.squeeze(0).cpu().numpy()
|
||||||
|
@ -45,12 +45,12 @@ def eval_net(net, dataset, gpu=False):
|
||||||
ax2 = fig.add_subplot(1, 4, 2)
|
ax2 = fig.add_subplot(1, 4, 2)
|
||||||
ax2.imshow(y)
|
ax2.imshow(y)
|
||||||
ax3 = fig.add_subplot(1, 4, 3)
|
ax3 = fig.add_subplot(1, 4, 3)
|
||||||
ax3.imshow((y_pred > 0.6))
|
ax3.imshow((y_pred > 0.5))
|
||||||
|
|
||||||
|
|
||||||
Q = dense_crf(((X*255).round()).astype(np.uint8), y_pred)
|
Q = dense_crf(((X*255).round()).astype(np.uint8), y_pred)
|
||||||
ax4 = fig.add_subplot(1, 4, 4)
|
ax4 = fig.add_subplot(1, 4, 4)
|
||||||
print(Q)
|
print(Q)
|
||||||
ax4.imshow(Q)
|
ax4.imshow(Q > 0.5)
|
||||||
plt.show()
|
plt.show()
|
||||||
return tot / i
|
return tot / i
|
||||||
|
|
5
load.py
5
load.py
|
@ -40,3 +40,8 @@ def get_imgs_and_masks(ids, dir_img, dir_mask):
|
||||||
masks = to_cropped_imgs(ids, dir_mask, '_mask.gif')
|
masks = to_cropped_imgs(ids, dir_mask, '_mask.gif')
|
||||||
|
|
||||||
return zip(imgs_normalized, masks)
|
return zip(imgs_normalized, masks)
|
||||||
|
|
||||||
|
def get_full_img_and_mask(id, dir_img, dir_mask):
|
||||||
|
im = Image.open(dir_img + id + '.jpg')
|
||||||
|
mask = Image.open(dir_mask + id + '_mask.gif')
|
||||||
|
return np.array(im), np.array(mask)
|
||||||
|
|
41
predict.py
Normal file
41
predict.py
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
import torch
|
||||||
|
from utils import *
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from PIL import Image
|
||||||
|
from unet_model import UNet
|
||||||
|
from torch.autograd import Variable
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from crf import dense_crf
|
||||||
|
|
||||||
|
|
||||||
|
def predict_img(net, full_img, gpu=False):
|
||||||
|
img = resize_and_crop(full_img)
|
||||||
|
|
||||||
|
left = get_square(img, 0)
|
||||||
|
right = get_square(img, 1)
|
||||||
|
|
||||||
|
right = normalize(right)
|
||||||
|
left = normalize(left)
|
||||||
|
|
||||||
|
right = np.transpose(right, axes=[2, 0, 1])
|
||||||
|
left = np.transpose(left, axes=[2, 0, 1])
|
||||||
|
|
||||||
|
X_l = torch.FloatTensor(left).unsqueeze(0)
|
||||||
|
X_r = torch.FloatTensor(right).unsqueeze(0)
|
||||||
|
|
||||||
|
if gpu:
|
||||||
|
X_l = Variable(X_l, volatile=True).cuda()
|
||||||
|
X_r = Variable(X_r, volatile=True).cuda()
|
||||||
|
else:
|
||||||
|
X_l = Variable(X_l, volatile=True)
|
||||||
|
X_r = Variable(X_r, volatile=True)
|
||||||
|
|
||||||
|
y_l = F.sigmoid(net(X_l))
|
||||||
|
y_r = F.sigmoid(net(X_r))
|
||||||
|
y_l = F.upsample_bilinear(y_l, scale_factor=2).data[0][0].cpu().numpy()
|
||||||
|
y_r = F.upsample_bilinear(y_r, scale_factor=2).data[0][0].cpu().numpy()
|
||||||
|
|
||||||
|
y = merge_masks(y_l, y_r, 1918)
|
||||||
|
yy = dense_crf(np.array(full_img).astype(np.uint8), y)
|
||||||
|
|
||||||
|
return yy > 0.5
|
25
submit.py
Normal file
25
submit.py
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
import os
|
||||||
|
from PIL import Image
|
||||||
|
from predict import *
|
||||||
|
from utils import encode
|
||||||
|
from unet_model import UNet
|
||||||
|
|
||||||
|
def submit(net, gpu=False):
|
||||||
|
dir = 'data/test/'
|
||||||
|
|
||||||
|
N = len(list(os.listdir(dir)))
|
||||||
|
with open('SUBMISSION.csv', 'w') as f:
|
||||||
|
f.write('img,rle_mask\n')
|
||||||
|
for index, i in enumerate(os.listdir(dir)):
|
||||||
|
print('{}/{}'.format(index, N))
|
||||||
|
img = Image.open(dir + i)
|
||||||
|
|
||||||
|
mask = predict_img(net, img, gpu)
|
||||||
|
enc = encode(mask)
|
||||||
|
f.write('{},{}\n'.format(i, ' '.join(map(str, enc))))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
net = UNet(3, 1).cuda()
|
||||||
|
net.load_state_dict(torch.load('INTERRUPTED.pth'))
|
||||||
|
submit(net, True)
|
36
utils.py
36
utils.py
|
@ -53,3 +53,39 @@ def split_train_val(dataset, val_percent=0.05):
|
||||||
|
|
||||||
def normalize(x):
|
def normalize(x):
|
||||||
return x / 255
|
return x / 255
|
||||||
|
|
||||||
|
|
||||||
|
def merge_masks(img1, img2, full_w):
|
||||||
|
w = img1.shape[1]
|
||||||
|
overlap = int(2 * w - full_w)
|
||||||
|
h = img1.shape[0]
|
||||||
|
|
||||||
|
new = np.zeros((h, full_w), np.float32)
|
||||||
|
|
||||||
|
margin = 0
|
||||||
|
|
||||||
|
new[:, :full_w//2+1] = img1[:, :full_w//2+1]
|
||||||
|
new[:, full_w//2+1:] = img2[:, -(full_w//2-1):]
|
||||||
|
#new[:, w-overlap+1+margin//2:-(w-overlap+margin//2)] = (img1[:, -overlap+margin:] +
|
||||||
|
# img2[:, :overlap-margin])/2
|
||||||
|
|
||||||
|
return new
|
||||||
|
|
||||||
|
|
||||||
|
def encode(mask):
|
||||||
|
"""mask : HxW"""
|
||||||
|
flat = mask.transpose().reshape(-1)
|
||||||
|
enc = []
|
||||||
|
i = 0
|
||||||
|
while i < len(flat): # sorry python
|
||||||
|
if(flat[i]):
|
||||||
|
s = i
|
||||||
|
while(flat[i]):
|
||||||
|
i += 1
|
||||||
|
e = i-1
|
||||||
|
if(s != e):
|
||||||
|
enc.append(s)
|
||||||
|
enc.append(e - s + 1)
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return enc
|
||||||
|
|
Loading…
Reference in a new issue