Added simple eval and test CRF
This commit is contained in:
parent
4063565295
commit
fa40396fff
28
crf.py
Normal file
28
crf.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
import numpy as np
|
||||
import pydensecrf.densecrf as dcrf
|
||||
|
||||
|
||||
def dense_crf(img, output_probs):
|
||||
h = output_probs.shape[0]
|
||||
w = output_probs.shape[1]
|
||||
|
||||
output_probs = np.expand_dims(output_probs, 0)
|
||||
output_probs = np.append(1 - output_probs, output_probs, axis=0)
|
||||
print(output_probs.shape)
|
||||
|
||||
d = dcrf.DenseCRF2D(w, h, 2)
|
||||
U = -np.log(output_probs)
|
||||
U = U.reshape((2, -1))
|
||||
U = np.ascontiguousarray(U)
|
||||
img = np.ascontiguousarray(img)
|
||||
|
||||
|
||||
d.setUnaryEnergy(U)
|
||||
|
||||
d.addPairwiseGaussian(sxy=10, compat=3)
|
||||
d.addPairwiseBilateral(sxy=50, srgb=20, rgbim=img, compat=10)
|
||||
|
||||
Q = d.inference(30)
|
||||
Q = np.argmax(np.array(Q), axis=0).reshape((h, w))
|
||||
|
||||
return Q
|
56
eval.py
Normal file
56
eval.py
Normal file
|
@ -0,0 +1,56 @@
|
|||
import torch
|
||||
from myloss import dice_coeff
|
||||
import numpy as np
|
||||
from torch.autograd import Variable
|
||||
from data_vis import plot_img_mask
|
||||
import matplotlib.pyplot as plt
|
||||
import torch.nn.functional as F
|
||||
from crf import dense_crf
|
||||
|
||||
|
||||
def eval_net(net, dataset, gpu=False):
|
||||
tot = 0
|
||||
for i, b in enumerate(dataset):
|
||||
X = b[0]
|
||||
y = b[1]
|
||||
|
||||
X = torch.FloatTensor(X).unsqueeze(0)
|
||||
y = torch.ByteTensor(y).unsqueeze(0)
|
||||
|
||||
if gpu:
|
||||
X = Variable(X, volatile=True).cuda()
|
||||
y = Variable(y, volatile=True).cuda()
|
||||
else:
|
||||
X = Variable(X, volatile=True)
|
||||
y = Variable(y, volatile=True)
|
||||
|
||||
y_pred = net(X)
|
||||
|
||||
y_pred = (F.sigmoid(y_pred) > 0.6).float()
|
||||
# y_pred = F.sigmoid(y_pred).float()
|
||||
|
||||
dice = dice_coeff(y_pred, y.float()).data[0]
|
||||
tot += dice
|
||||
|
||||
if 0:
|
||||
X = X.data.squeeze(0).cpu().numpy()
|
||||
X = np.transpose(X, axes=[1, 2, 0])
|
||||
y = y.data.squeeze(0).cpu().numpy()
|
||||
y_pred = y_pred.data.squeeze(0).squeeze(0).cpu().numpy()
|
||||
print(y_pred.shape)
|
||||
|
||||
fig = plt.figure()
|
||||
ax1 = fig.add_subplot(1, 4, 1)
|
||||
ax1.imshow(X)
|
||||
ax2 = fig.add_subplot(1, 4, 2)
|
||||
ax2.imshow(y)
|
||||
ax3 = fig.add_subplot(1, 4, 3)
|
||||
ax3.imshow((y_pred > 0.6))
|
||||
|
||||
|
||||
Q = dense_crf(((X*255).round()).astype(np.uint8), y_pred)
|
||||
ax4 = fig.add_subplot(1, 4, 4)
|
||||
print(Q)
|
||||
ax4.imshow(Q)
|
||||
plt.show()
|
||||
return tot / i
|
105
main.py
105
main.py
|
@ -8,17 +8,20 @@ from torch import optim
|
|||
#data manipulation
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import cv2
|
||||
import PIL
|
||||
|
||||
#load files
|
||||
import os
|
||||
|
||||
#data vis
|
||||
#data visualization
|
||||
from data_vis import plot_img_mask
|
||||
from utils import *
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
#quit after interrupt
|
||||
import sys
|
||||
|
||||
|
||||
|
||||
dir = 'data'
|
||||
ids = []
|
||||
|
@ -33,69 +36,71 @@ np.random.shuffle(ids)
|
|||
|
||||
|
||||
net = UNet(3, 1)
|
||||
net.cuda()
|
||||
|
||||
optimizer = optim.Adam(net.parameters(), lr=0.001)
|
||||
criterion = DiceLoss()
|
||||
def train(net):
|
||||
optimizer = optim.Adam(net.parameters(), lr=1)
|
||||
criterion = DiceLoss()
|
||||
|
||||
dataset = []
|
||||
epochs = 5
|
||||
for epoch in range(epochs):
|
||||
print('epoch {}/{}...'.format(epoch+1, epochs))
|
||||
l = 0
|
||||
epochs = 5
|
||||
for epoch in range(epochs):
|
||||
print('epoch {}/{}...'.format(epoch+1, epochs))
|
||||
l = 0
|
||||
|
||||
for i, c in enumerate(ids):
|
||||
id = c[0]
|
||||
pos = c[1]
|
||||
im = PIL.Image.open(dir + '/train/' + id + '.jpg')
|
||||
im = resize_and_crop(im)
|
||||
for i, c in enumerate(ids):
|
||||
id = c[0]
|
||||
pos = c[1]
|
||||
im = PIL.Image.open(dir + '/train/' + id + '.jpg')
|
||||
im = resize_and_crop(im)
|
||||
|
||||
ma = PIL.Image.open(dir + '/train_masks/' + id + '_mask.gif')
|
||||
ma = resize_and_crop(ma)
|
||||
ma = PIL.Image.open(dir + '/train_masks/' + id + '_mask.gif')
|
||||
ma = resize_and_crop(ma)
|
||||
|
||||
left, right = split_into_squares(np.array(im))
|
||||
left_m, right_m = split_into_squares(np.array(ma))
|
||||
left, right = split_into_squares(np.array(im))
|
||||
left_m, right_m = split_into_squares(np.array(ma))
|
||||
|
||||
if pos == 0:
|
||||
X = left
|
||||
y = left_m
|
||||
else:
|
||||
X = right
|
||||
y = right_m
|
||||
if pos == 0:
|
||||
X = left
|
||||
y = left_m
|
||||
else:
|
||||
X = right
|
||||
y = right_m
|
||||
|
||||
|
||||
X = np.transpose(X, axes=[2, 0, 1])
|
||||
X = torch.FloatTensor(X / 255).unsqueeze(0)
|
||||
y = Variable(torch.ByteTensor(y))
|
||||
X = np.transpose(X, axes=[2, 0, 1])
|
||||
X = torch.FloatTensor(X / 255).unsqueeze(0).cuda()
|
||||
y = Variable(torch.ByteTensor(y)).cuda()
|
||||
|
||||
X = Variable(X, requires_grad=False)
|
||||
X = Variable(X).cuda()
|
||||
|
||||
optimizer.zero_grad()
|
||||
optimizer.zero_grad()
|
||||
|
||||
y_pred = net(X).squeeze(1)
|
||||
y_pred = net(X).squeeze(1)
|
||||
|
||||
|
||||
loss = criterion(y_pred, y.unsqueeze(0).float())
|
||||
loss = criterion(y_pred, y.unsqueeze(0).float())
|
||||
|
||||
l += loss.data[0]
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
l += loss.data[0]
|
||||
loss.backward()
|
||||
if i%10 == 0:
|
||||
optimizer.step()
|
||||
print('Stepped')
|
||||
|
||||
print('{0:.4f}%.'.format(i/len(ids)*100, end=''))
|
||||
print('{0:.4f}%\t\t{1:.6f}'.format(i/len(ids)*100, loss.data[0]))
|
||||
|
||||
print('Loss : {}'.format(l))
|
||||
l = l / len(ids)
|
||||
print('Loss : {}'.format(l))
|
||||
torch.save(net.state_dict(), 'MODEL_EPOCH{}_LOSS{}.pth'.format(epoch+1, l))
|
||||
print('Saved')
|
||||
|
||||
try:
|
||||
net.load_state_dict(torch.load('MODEL_INTERRUPTED.pth'))
|
||||
train(net)
|
||||
|
||||
#%%
|
||||
|
||||
|
||||
|
||||
|
||||
#net = UNet(3, 2)
|
||||
|
||||
#x = Variable(torch.FloatTensor(np.random.randn(1, 3, 640, 640)))
|
||||
|
||||
#y = net(x)
|
||||
|
||||
|
||||
#plt.imshow(y[0])
|
||||
#plt.show()
|
||||
except KeyboardInterrupt:
|
||||
print('Interrupted')
|
||||
torch.save(net.state_dict(), 'MODEL_INTERRUPTED.pth')
|
||||
try:
|
||||
sys.exit(0)
|
||||
except SystemExit:
|
||||
os._exit(0)
|
||||
|
|
77
train.py
77
train.py
|
@ -1,13 +1,19 @@
|
|||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
|
||||
from load import *
|
||||
from data_vis import *
|
||||
from utils import split_train_val, batch
|
||||
from myloss import DiceLoss
|
||||
from eval import eval_net
|
||||
from unet_model import UNet
|
||||
from torch.autograd import Variable
|
||||
from torch import optim
|
||||
from optparse import OptionParser
|
||||
import sys
|
||||
import os
|
||||
|
||||
|
||||
def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
|
||||
|
@ -39,14 +45,21 @@ def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
|
|||
train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask)
|
||||
val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask)
|
||||
|
||||
optimizer = optim.Adam(net.parameters(), lr=lr)
|
||||
criterion = DiceLoss()
|
||||
optimizer = optim.SGD(net.parameters(),
|
||||
lr=lr, momentum=0.9, weight_decay=0.0005)
|
||||
criterion = nn.BCELoss()
|
||||
|
||||
for epoch in range(epochs):
|
||||
print('Starting epoch {}/{}.'.format(epoch+1, epochs))
|
||||
train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask)
|
||||
val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask)
|
||||
|
||||
epoch_loss = 0
|
||||
|
||||
if 1:
|
||||
val_dice = eval_net(net, val, gpu)
|
||||
print('Validation Dice Coeff: {}'.format(val_dice))
|
||||
|
||||
for i, b in enumerate(batch(train, batch_size)):
|
||||
X = np.array([i[0] for i in b])
|
||||
y = np.array([i[1] for i in b])
|
||||
|
@ -61,17 +74,22 @@ def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
|
|||
X = Variable(X)
|
||||
y = Variable(y)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
y_pred = net(X)
|
||||
probs = F.sigmoid(y_pred)
|
||||
probs_flat = probs.view(-1)
|
||||
|
||||
loss = criterion(y_pred, y.float())
|
||||
y_flat = y.view(-1)
|
||||
|
||||
loss = criterion(probs_flat, y_flat.float())
|
||||
epoch_loss += loss.data[0]
|
||||
|
||||
print('{0:.4f} --- loss: {1:.6f}'.format(i*batch_size/N_train,
|
||||
loss.data[0]))
|
||||
loss.data[0]))
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
loss.backward()
|
||||
|
||||
optimizer.step()
|
||||
|
||||
print('Epoch finished ! Loss: {}'.format(epoch_loss/i))
|
||||
|
@ -83,23 +101,38 @@ def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
|
|||
print('Checkpoint {} saved !'.format(epoch+1))
|
||||
|
||||
|
||||
parser = OptionParser()
|
||||
parser.add_option("-e", "--epochs", dest="epochs", default=5, type="int",
|
||||
help="number of epochs")
|
||||
parser.add_option("-b", "--batch-size", dest="batchsize", default=10,
|
||||
type="int", help="batch size")
|
||||
parser.add_option("-l", "--learning-rate", dest="lr", default=0.1,
|
||||
type="int", help="learning rate")
|
||||
parser.add_option("-g", "--gpu", action="store_true", dest="gpu",
|
||||
default=False, help="use cuda")
|
||||
parser.add_option("-n", "--ngpu", action="store_false", dest="gpu",
|
||||
default=False, help="use cuda")
|
||||
if __name__ == '__main__':
|
||||
parser = OptionParser()
|
||||
parser.add_option('-e', '--epochs', dest='epochs', default=5, type='int',
|
||||
help='number of epochs')
|
||||
parser.add_option('-b', '--batch-size', dest='batchsize', default=10,
|
||||
type='int', help='batch size')
|
||||
parser.add_option('-l', '--learning-rate', dest='lr', default=0.1,
|
||||
type='float', help='learning rate')
|
||||
parser.add_option('-g', '--gpu', action='store_true', dest='gpu',
|
||||
default=False, help='use cuda')
|
||||
parser.add_option('-c', '--load', dest='load',
|
||||
default=False, help='load file model')
|
||||
|
||||
(options, args) = parser.parse_args()
|
||||
|
||||
(options, args) = parser.parse_args()
|
||||
net = UNet(3, 1)
|
||||
|
||||
net = UNet(3, 1)
|
||||
if options.gpu:
|
||||
net.cuda()
|
||||
if options.load:
|
||||
net.load_state_dict(torch.load(options.load))
|
||||
print('Model loaded from {}'.format(options.load))
|
||||
|
||||
train_net(net, options.epochs, options.batchsize, options.lr, gpu=options.gpu)
|
||||
if options.gpu:
|
||||
net.cuda()
|
||||
cudnn.benchmark = True
|
||||
|
||||
try:
|
||||
train_net(net, options.epochs, options.batchsize, options.lr,
|
||||
gpu=options.gpu)
|
||||
except KeyboardInterrupt:
|
||||
torch.save(net.state_dict(), 'INTERRUPTED.pth')
|
||||
print('Saved interrupt')
|
||||
try:
|
||||
sys.exit(0)
|
||||
except SystemExit:
|
||||
os._exit(0)
|
||||
|
|
|
@ -10,9 +10,11 @@ class double_conv(nn.Module):
|
|||
super(double_conv, self).__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(in_ch, out_ch, 3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.BatchNorm2d(out_ch),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_ch, out_ch, 3, padding=1),
|
||||
nn.ReLU()
|
||||
nn.BatchNorm2d(out_ch),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
|
3
utils.py
3
utils.py
|
@ -13,7 +13,7 @@ def get_square(img, pos):
|
|||
return img[:, -h:]
|
||||
|
||||
|
||||
def resize_and_crop(pilimg, scale=0.2, final_height=None):
|
||||
def resize_and_crop(pilimg, scale=0.5, final_height=None):
|
||||
w = pilimg.size[0]
|
||||
h = pilimg.size[1]
|
||||
newW = int(w * scale)
|
||||
|
@ -46,6 +46,7 @@ def split_train_val(dataset, val_percent=0.05):
|
|||
dataset = list(dataset)
|
||||
length = len(dataset)
|
||||
n = int(length * val_percent)
|
||||
random.seed(42)
|
||||
random.shuffle(dataset)
|
||||
return {'train': dataset[:-n], 'val': dataset[-n:]}
|
||||
|
||||
|
|
Loading…
Reference in a new issue