REVA-QCAV/predict.py

42 lines
1.1 KiB
Python
Raw Normal View History

2017-08-21 16:00:07 +00:00
import torch
from utils import *
import torch.nn.functional as F
from PIL import Image
from unet_model import UNet
from torch.autograd import Variable
import matplotlib.pyplot as plt
from crf import dense_crf
def _predict_half(net, img, pos, gpu=False):
    """Run ``net`` on one square crop of ``img`` and return the mask as an array.

    The crop is taken with ``get_square(img, pos)``, passed through
    ``normalize``, transposed with ``axes=[2, 0, 1]``, wrapped in a batch of
    size one and fed to ``net``. The network output goes through a sigmoid
    and a 2x bilinear upsampling; element ``[0][0]`` of the result is
    returned as a NumPy array.

    Args:
        net: Callable model applied to the prepared input batch.
        img: Image as returned by ``resize_and_crop``.
        pos: Position selector forwarded to ``get_square`` (the caller uses
            0 for the "left" crop and 1 for the "right" crop).
        gpu: When true, the input batch is moved to CUDA before the forward
            pass.

    Returns:
        NumPy array holding element ``[0][0]`` of the upsampled sigmoid
        output (presumably first batch item, first channel -- TODO confirm
        against the model's output layout).
    """
    square = normalize(get_square(img, pos))

    # Move the last axis to the front, then add a leading batch dimension.
    # NOTE(review): `np` is not imported explicitly in this file's visible
    # imports; it presumably comes into scope via `from utils import *`.
    chw = np.transpose(square, axes=[2, 0, 1])

    # NOTE(review): `volatile=True` is the pre-0.4 PyTorch way to disable
    # autograd for inference. Newer PyTorch releases ignore it with a
    # warning; `torch.no_grad()` is the replacement if the project upgrades.
    batch = Variable(torch.FloatTensor(chw).unsqueeze(0), volatile=True)
    if gpu:
        batch = batch.cuda()

    probs = F.sigmoid(net(batch))
    upsampled = F.upsample_bilinear(probs, scale_factor=2)

    # `.cpu()` makes the conversion to NumPy work for both CPU and CUDA runs.
    return upsampled.data[0][0].cpu().numpy()


def predict_img(net, full_img, gpu=False):
    """Predict a thresholded mask for ``full_img`` using ``net``.

    The image is first passed through ``resize_and_crop``. Two square crops
    (``get_square`` positions 0 and 1, called "left" and "right" here) are
    run through the network independently, the two resulting masks are
    combined with ``merge_masks``, and the merged mask is refined with
    ``dense_crf`` against the original image before thresholding at 0.5.

    Args:
        net: Callable model applied to each prepared crop.
        full_img: Input image; it must be accepted by ``resize_and_crop``
            and convertible with ``np.array`` (presumably a PIL image --
            TODO confirm against callers).
        gpu: When true, inputs are moved to CUDA before the forward pass.

    Returns:
        The result of ``dense_crf(...) > 0.5`` (presumably a boolean NumPy
        array -- depends on what ``dense_crf`` returns).
    """
    img = resize_and_crop(full_img)

    # Left crop is evaluated before the right one, as in the original code.
    y_l = _predict_half(net, img, 0, gpu)
    y_r = _predict_half(net, img, 1, gpu)

    # NOTE(review): 1918 is a hard-coded third argument to `merge_masks`,
    # presumably the full-resolution image width -- confirm before reusing
    # this function on images of a different size.
    y = merge_masks(y_l, y_r, 1918)

    yy = dense_crf(np.array(full_img).astype(np.uint8), y)
    return yy > 0.5