Migration to PyTorch 0.4, code cleanup

Former-commit-id: c981801ccc3b74047e94c76e67c4ff1f3097226c
This commit is contained in:
milesial 2018-06-08 19:27:32 +02:00
parent 90e988c10f
commit 02e2314149
11 changed files with 214 additions and 204 deletions

View file

@ -1,17 +1,12 @@
#
# myloss.py : implementation of the Dice coeff and the associated loss
#
import torch
from torch.autograd import Function, Variable
class DiceCoeff(Function):
"""Dice coeff for individual examples"""
def forward(self, input, target):
self.save_for_backward(input, target)
self.inter = torch.dot(input, target) + 0.0001
self.inter = torch.dot(input.view(-1), target.view(-1)) + 0.0001
self.union = torch.sum(input) + torch.sum(target) + 0.0001
t = 2 * self.inter.float() / self.union.float()
@ -35,9 +30,9 @@ class DiceCoeff(Function):
def dice_coeff(input, target):
"""Dice coeff for batches"""
if input.is_cuda:
s = Variable(torch.FloatTensor(1).cuda().zero_())
s = torch.FloatTensor(1).cuda().zero_()
else:
s = Variable(torch.FloatTensor(1).zero_())
s = torch.FloatTensor(1).zero_()
for i, c in enumerate(zip(input, target)):
s = s + DiceCoeff().forward(c[0], c[1])

52
eval.py
View file

@ -1,55 +1,25 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from myloss import dice_coeff
from utils import dense_crf
from dice_loss import dice_coeff
def eval_net(net, dataset, gpu=False):
"""Evaluation without the densecrf with the dice coefficient"""
tot = 0
for i, b in enumerate(dataset):
X = b[0]
y = b[1]
img = b[0]
true_mask = b[1]
X = torch.FloatTensor(X).unsqueeze(0)
y = torch.ByteTensor(y).unsqueeze(0)
img = torch.from_numpy(img).unsqueeze(0)
true_mask = torch.from_numpy(true_mask).unsqueeze(0)
if gpu:
X = Variable(X, volatile=True).cuda()
y = Variable(y, volatile=True).cuda()
else:
X = Variable(X, volatile=True)
y = Variable(y, volatile=True)
img = img.cuda()
true_mask = true_mask.cuda()
y_pred = net(X)
mask_pred = net(img)[0]
mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
y_pred = (F.sigmoid(y_pred) > 0.6).float()
# y_pred = F.sigmoid(y_pred).float()
dice = dice_coeff(y_pred, y.float()).data[0]
tot += dice
if 0:
X = X.data.squeeze(0).cpu().numpy()
X = np.transpose(X, axes=[1, 2, 0])
y = y.data.squeeze(0).cpu().numpy()
y_pred = y_pred.data.squeeze(0).squeeze(0).cpu().numpy()
print(y_pred.shape)
fig = plt.figure()
ax1 = fig.add_subplot(1, 4, 1)
ax1.imshow(X)
ax2 = fig.add_subplot(1, 4, 2)
ax2.imshow(y)
ax3 = fig.add_subplot(1, 4, 3)
ax3.imshow((y_pred > 0.5))
Q = dense_crf(((X * 255).round()).astype(np.uint8), y_pred)
ax4 = fig.add_subplot(1, 4, 4)
print(Q)
ax4.imshow(Q > 0.5)
plt.show()
tot += dice_coeff(mask_pred, true_mask).item()
return tot / i

View file

@ -1,48 +1,64 @@
import argparse
import os
import numpy
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from PIL import Image
from unet import UNet
from utils import *
from utils import resize_and_crop, normalize, split_img_into_squares, hwc_to_chw, merge_masks, dense_crf
from utils import plot_img_and_mask
def predict_img(net,
full_img,
scale_factor=0.5,
out_threshold=0.5,
use_dense_crf=True,
use_gpu=False):
img_height = full_img.size[1]
img_width = full_img.size[0]
img = resize_and_crop(full_img, scale=scale_factor)
img = normalize(img)
left_square, right_square = split_img_into_squares(img)
left_square = hwc_to_chw(left_square)
right_square = hwc_to_chw(right_square)
X_left = torch.from_numpy(left_square).unsqueeze(0)
X_right = torch.from_numpy(right_square).unsqueeze(0)
if use_gpu:
X_left = X_left.cuda()
X_right = X_right.cuda()
with torch.no_grad():
output_left = net(X_left)
output_right = net(X_right)
left_probs = F.sigmoid(output_left)
right_probs = F.sigmoid(output_right)
left_probs = F.upsample(left_probs, size=(img_height, img_height))
right_probs = F.upsample(right_probs, size=(img_height, img_height))
left_mask_np = left_probs.squeeze().cpu().numpy()
right_mask_np = right_probs.squeeze().cpu().numpy()
full_mask = merge_masks(left_mask_np, right_mask_np, img_width)
if use_dense_crf:
full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask)
return full_mask > out_threshold
def predict_img(net, full_img, gpu=False):
img = resize_and_crop(full_img)
left = get_square(img, 0)
right = get_square(img, 1)
right = normalize(right)
left = normalize(left)
right = np.transpose(right, axes=[2, 0, 1])
left = np.transpose(left, axes=[2, 0, 1])
X_l = torch.FloatTensor(left).unsqueeze(0)
X_r = torch.FloatTensor(right).unsqueeze(0)
if gpu:
X_l = Variable(X_l, volatile=True).cuda()
X_r = Variable(X_r, volatile=True).cuda()
else:
X_l = Variable(X_l, volatile=True)
X_r = Variable(X_r, volatile=True)
y_l = F.sigmoid(net(X_l))
y_r = F.sigmoid(net(X_r))
y_l = F.upsample_bilinear(y_l, scale_factor=2).data[0][0].cpu().numpy()
y_r = F.upsample_bilinear(y_r, scale_factor=2).data[0][0].cpu().numpy()
y = merge_masks(y_l, y_r, full_img.size[0])
yy = dense_crf(np.array(full_img).astype(np.uint8), y)
return yy > 0.5
if __name__ == "__main__":
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--model', '-m', default='MODEL.pth',
metavar='FILE',
@ -61,19 +77,22 @@ if __name__ == "__main__":
parser.add_argument('--no-save', '-n', action='store_false',
help="Do not save the output masks",
default=False)
parser.add_argument('--no-crf', '-r', action='store_false',
help="Do not use dense CRF postprocessing",
default=False)
parser.add_argument('--mask-threshold', '-t', type=float,
help="Minimum probability value to consider a mask pixel white",
default=0.5)
parser.add_argument('--scale', '-s', type=float,
help="Scale factor for the input images",
default=0.5)
args = parser.parse_args()
print("Using model file : {}".format(args.model))
net = UNet(3, 1)
if not args.cpu:
print("Using CUDA version of the net, prepare your GPU !")
net.cuda()
else:
net.cpu()
print("Using CPU version of the net, this may be very slow")
return parser.parse_args()
def get_output_filenames(args):
in_files = args.input
out_files = []
if not args.output:
for f in in_files:
pathsplit = os.path.splitext(f)
@ -84,32 +103,52 @@ if __name__ == "__main__":
else:
out_files = args.output
print("Loading model ...")
net.load_state_dict(torch.load(args.model))
return out_files
def mask_to_image(mask):
return Image.fromarray((mask * 255).astype(np.uint8))
if __name__ == "__main__":
args = get_args()
in_files = args.input
out_files = get_output_filenames(args)
net = UNet(n_channels=3, n_classes=1)
print("Loading model {}".format(args.model))
if not args.cpu:
print("Using CUDA version of the net, prepare your GPU !")
net.cuda()
net.load_state_dict(torch.load(args.model))
else:
net.cpu()
net.load_state_dict(torch.load(args.model, map_location='cpu'))
print("Using CPU version of the net, this may be very slow")
print("Model loaded !")
for i, fn in enumerate(in_files):
print("\nPredicting image {} ...".format(fn))
img = Image.open(fn)
out = predict_img(net, img, not args.cpu)
if img.size[0] < img.size[1]:
print("Error: image height larger than the width")
mask = predict_img(net=net,
full_img=img,
scale_factor=args.scale,
out_threshold=args.mask_threshold,
use_dense_crf= not args.no_crf,
use_gpu=not args.cpu)
if args.viz:
print("Vizualising results for image {}, close to continue ..."
.format(fn))
fig = plt.figure()
a = fig.add_subplot(1, 2, 1)
a.set_title('Input image')
plt.imshow(img)
b = fig.add_subplot(1, 2, 2)
b.set_title('Output mask')
plt.imshow(out)
plt.show()
print("Visualizing results for image {}, close to continue ...".format(fn))
plot_img_and_mask(img, mask)
if not args.no_save:
out_fn = out_files[i]
result = Image.fromarray((out * 255).astype(numpy.uint8))
result = mask_to_image(mask)
result.save(out_files[i])
print("Mask saved to {}".format(out_files[i]))

View file

@ -1,10 +1,15 @@
# used to predict all test images and encode results in a csv file
import os
from PIL import Image
from predict import *
import torch
from predict import predict_img
from utils import rle_encode
from unet import UNet
def submit(net, gpu=False):
"""Used for Kaggle submission: predicts and encode all test images"""
dir = 'data/test/'
N = len(list(os.listdir(dir)))

120
train.py
View file

@ -1,20 +1,27 @@
import sys
import os
from optparse import OptionParser
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.autograd import Variable
from eval import eval_net
from unet import UNet
from utils import *
from utils import get_ids, split_ids, split_train_val, get_imgs_and_masks, batch
def train_net(net,
epochs=5,
batch_size=1,
lr=0.1,
val_percent=0.05,
save_cp=True,
gpu=False,
img_scale=0.5):
def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
cp=True, gpu=False):
dir_img = 'data/train/'
dir_mask = 'data/train_masks/'
dir_checkpoint = 'checkpoints/'
@ -34,69 +41,66 @@ def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
Checkpoints: {}
CUDA: {}
'''.format(epochs, batch_size, lr, len(iddataset['train']),
len(iddataset['val']), str(cp), str(gpu)))
len(iddataset['val']), str(save_cp), str(gpu)))
N_train = len(iddataset['train'])
optimizer = optim.SGD(net.parameters(),
lr=lr, momentum=0.9, weight_decay=0.0005)
lr=lr,
momentum=0.9,
weight_decay=0.0005)
criterion = nn.BCELoss()
for epoch in range(epochs):
print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
# reset the generators
train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask)
val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask)
train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask, img_scale)
val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask, img_scale)
epoch_loss = 0
for i, b in enumerate(batch(train, batch_size)):
imgs = np.array([i[0] for i in b]).astype(np.float32)
true_masks = np.array([i[1] for i in b])
imgs = torch.from_numpy(imgs)
true_masks = torch.from_numpy(true_masks)
if gpu:
imgs = imgs.cuda()
true_masks = true_masks.cuda()
masks_pred = net(imgs)
masks_probs = F.sigmoid(masks_pred)
masks_probs_flat = masks_probs.view(-1)
true_masks_flat = true_masks.view(-1)
loss = criterion(masks_probs_flat, true_masks_flat)
epoch_loss += loss.item()
print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train, loss.item()))
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Epoch finished ! Loss: {}'.format(epoch_loss / i))
if 1:
val_dice = eval_net(net, val, gpu)
print('Validation Dice Coeff: {}'.format(val_dice))
for i, b in enumerate(batch(train, batch_size)):
X = np.array([i[0] for i in b])
y = np.array([i[1] for i in b])
X = torch.FloatTensor(X)
y = torch.ByteTensor(y)
if gpu:
X = Variable(X).cuda()
y = Variable(y).cuda()
else:
X = Variable(X)
y = Variable(y)
y_pred = net(X)
probs = F.sigmoid(y_pred)
probs_flat = probs.view(-1)
y_flat = y.view(-1)
loss = criterion(probs_flat, y_flat.float())
epoch_loss += loss.data[0]
print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train,
loss.data[0]))
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Epoch finished ! Loss: {}'.format(epoch_loss / i))
if cp:
if save_cp:
torch.save(net.state_dict(),
dir_checkpoint + 'CP{}.pth'.format(epoch + 1))
print('Checkpoint {} saved !'.format(epoch + 1))
if __name__ == '__main__':
def get_args():
parser = OptionParser()
parser.add_option('-e', '--epochs', dest='epochs', default=5, type='int',
help='number of epochs')
@ -108,22 +112,32 @@ if __name__ == '__main__':
default=False, help='use cuda')
parser.add_option('-c', '--load', dest='load',
default=False, help='load file model')
parser.add_option('-s', '--scale', dest='scale', type='float',
default=0.5, help='downscaling factor of the images')
(options, args) = parser.parse_args()
return options
net = UNet(3, 1)
if __name__ == '__main__':
args = get_args()
if options.load:
net.load_state_dict(torch.load(options.load))
print('Model loaded from {}'.format(options.load))
net = UNet(n_channels=3, n_classes=1)
if options.gpu:
if args.load:
net.load_state_dict(torch.load(args.load))
print('Model loaded from {}'.format(args.load))
if args.gpu:
net.cuda()
cudnn.benchmark = True
# cudnn.benchmark = True # faster convolutions, but more memory
try:
train_net(net, options.epochs, options.batchsize, options.lr,
gpu=options.gpu)
train_net(net=net,
epochs=args.epochs,
batch_size=args.batchsize,
lr=args.lr,
gpu=args.gpu,
img_scale=args.scale)
except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth')
print('Saved interrupt')

View file

@ -1,14 +1,7 @@
#!/usr/bin/python
# full assembly of the sub-parts to form the complete net
import torch
import torch.nn as nn
import torch.nn.functional as F
# python 3 confusing imports :(
from .unet_parts import *
class UNet(nn.Module):
def __init__(self, n_channels, n_classes):
super(UNet, self).__init__()

View file

@ -1,5 +1,3 @@
#!/usr/bin/python
# sub-parts of the U-Net model
import torch
@ -53,9 +51,9 @@ class up(nn.Module):
super(up, self).__init__()
# would be a nice idea if the upsampling could be learned too,
#  but my machine do not have enough memory to handle all those weights
# but my machine do not have enough memory to handle all those weights
if bilinear:
self.up = nn.UpsamplingBilinear2d(scale_factor=2)
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
else:
self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)

View file

@ -1,7 +1,6 @@
import numpy as np
import pydensecrf.densecrf as dcrf
def dense_crf(img, output_probs):
h = output_probs.shape[0]
w = output_probs.shape[1]

View file

@ -1,13 +1,12 @@
import matplotlib.pyplot as plt
def plot_img_mask(img, mask):
def plot_img_and_mask(img, mask):
fig = plt.figure()
a = fig.add_subplot(1, 2, 1)
a.set_title('Input image')
plt.imshow(img)
ax1 = fig.add_subplot(1, 3, 1)
ax1.imshow(img)
ax2 = fig.add_subplot(1, 3, 2)
ax2.imshow(mask)
plt.show()
b = fig.add_subplot(1, 2, 2)
b.set_title('Output mask')
plt.imshow(mask)
plt.show()

View file

@ -3,12 +3,11 @@
# cropped images and masks
import os
from functools import partial
import numpy as np
from PIL import Image
from .utils import resize_and_crop, get_square, normalize
from .utils import resize_and_crop, get_square, normalize, hwc_to_chw
def get_ids(dir):
@ -21,23 +20,22 @@ def split_ids(ids, n=2):
return ((id, i) for i in range(n) for id in ids)
def to_cropped_imgs(ids, dir, suffix):
def to_cropped_imgs(ids, dir, suffix, scale):
"""From a list of tuples, returns the correct cropped img"""
for id, pos in ids:
im = resize_and_crop(Image.open(dir + id + suffix))
im = resize_and_crop(Image.open(dir + id + suffix), scale=scale)
yield get_square(im, pos)
def get_imgs_and_masks(ids, dir_img, dir_mask):
def get_imgs_and_masks(ids, dir_img, dir_mask, scale):
"""Return all the couples (img, mask)"""
imgs = to_cropped_imgs(ids, dir_img, '.jpg')
imgs = to_cropped_imgs(ids, dir_img, '.jpg', scale)
# need to transform from HWC to CHW
imgs_switched = map(partial(np.transpose, axes=[2, 0, 1]), imgs)
imgs_switched = map(hwc_to_chw, imgs)
imgs_normalized = map(normalize, imgs_switched)
masks = to_cropped_imgs(ids, dir_mask, '_mask.gif')
masks = to_cropped_imgs(ids, dir_mask, '_mask.gif', scale)
return zip(imgs_normalized, masks)

View file

@ -1,17 +1,20 @@
import random
import numpy as np
def get_square(img, pos):
"""Extract a left or a right square from PILimg shape : (H, W, C))"""
img = np.array(img)
"""Extract a left or a right square from ndarray shape : (H, W, C))"""
h = img.shape[0]
if pos == 0:
return img[:, :h]
else:
return img[:, -h:]
def split_img_into_squares(img):
return get_square(img, 0), get_square(img, 1)
def hwc_to_chw(img):
return np.transpose(img, axes=[2, 0, 1])
def resize_and_crop(pilimg, scale=0.5, final_height=None):
w = pilimg.size[0]
@ -26,8 +29,7 @@ def resize_and_crop(pilimg, scale=0.5, final_height=None):
img = pilimg.resize((newW, newH))
img = img.crop((0, diff // 2, newW, newH - diff // 2))
return img
return np.array(img, dtype=np.float32)
def batch(iterable, batch_size):
"""Yields lists by batch"""
@ -41,7 +43,6 @@ def batch(iterable, batch_size):
if len(b) > 0:
yield b
def split_train_val(dataset, val_percent=0.05):
dataset = list(dataset)
length = len(dataset)
@ -53,18 +54,17 @@ def split_train_val(dataset, val_percent=0.05):
def normalize(x):
return x / 255
def merge_masks(img1, img2, full_w):
h = img1.shape[0]
new = np.zeros((h, full_w), np.float32)
new[:, :full_w // 2 + 1] = img1[:, :full_w // 2 + 1]
new[:, full_w // 2 + 1:] = img2[:, -(full_w // 2 - 1):]
return new
# credits to https://stackoverflow.com/users/6076729/manuel-lagunas
def rle_encode(mask_image):
pixels = mask_image.flatten()
# We avoid issues with '1' at the start or end (at the corners of