Migration to PyTorch 0.4, code cleanup

Former-commit-id: c981801ccc3b74047e94c76e67c4ff1f3097226c
This commit is contained in:
milesial 2018-06-08 19:27:32 +02:00
parent 90e988c10f
commit 02e2314149
11 changed files with 214 additions and 204 deletions

View file

@ -1,17 +1,12 @@
#
# myloss.py : implementation of the Dice coeff and the associated loss
#
import torch import torch
from torch.autograd import Function, Variable from torch.autograd import Function, Variable
class DiceCoeff(Function): class DiceCoeff(Function):
"""Dice coeff for individual examples""" """Dice coeff for individual examples"""
def forward(self, input, target): def forward(self, input, target):
self.save_for_backward(input, target) self.save_for_backward(input, target)
self.inter = torch.dot(input, target) + 0.0001 self.inter = torch.dot(input.view(-1), target.view(-1)) + 0.0001
self.union = torch.sum(input) + torch.sum(target) + 0.0001 self.union = torch.sum(input) + torch.sum(target) + 0.0001
t = 2 * self.inter.float() / self.union.float() t = 2 * self.inter.float() / self.union.float()
@ -35,9 +30,9 @@ class DiceCoeff(Function):
def dice_coeff(input, target): def dice_coeff(input, target):
"""Dice coeff for batches""" """Dice coeff for batches"""
if input.is_cuda: if input.is_cuda:
s = Variable(torch.FloatTensor(1).cuda().zero_()) s = torch.FloatTensor(1).cuda().zero_()
else: else:
s = Variable(torch.FloatTensor(1).zero_()) s = torch.FloatTensor(1).zero_()
for i, c in enumerate(zip(input, target)): for i, c in enumerate(zip(input, target)):
s = s + DiceCoeff().forward(c[0], c[1]) s = s + DiceCoeff().forward(c[0], c[1])

52
eval.py
View file

@ -1,55 +1,25 @@
import matplotlib.pyplot as plt
import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.autograd import Variable
from myloss import dice_coeff from dice_loss import dice_coeff
from utils import dense_crf
def eval_net(net, dataset, gpu=False): def eval_net(net, dataset, gpu=False):
"""Evaluation without the densecrf with the dice coefficient"""
tot = 0 tot = 0
for i, b in enumerate(dataset): for i, b in enumerate(dataset):
X = b[0] img = b[0]
y = b[1] true_mask = b[1]
X = torch.FloatTensor(X).unsqueeze(0) img = torch.from_numpy(img).unsqueeze(0)
y = torch.ByteTensor(y).unsqueeze(0) true_mask = torch.from_numpy(true_mask).unsqueeze(0)
if gpu: if gpu:
X = Variable(X, volatile=True).cuda() img = img.cuda()
y = Variable(y, volatile=True).cuda() true_mask = true_mask.cuda()
else:
X = Variable(X, volatile=True)
y = Variable(y, volatile=True)
y_pred = net(X) mask_pred = net(img)[0]
mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
y_pred = (F.sigmoid(y_pred) > 0.6).float() tot += dice_coeff(mask_pred, true_mask).item()
# y_pred = F.sigmoid(y_pred).float()
dice = dice_coeff(y_pred, y.float()).data[0]
tot += dice
if 0:
X = X.data.squeeze(0).cpu().numpy()
X = np.transpose(X, axes=[1, 2, 0])
y = y.data.squeeze(0).cpu().numpy()
y_pred = y_pred.data.squeeze(0).squeeze(0).cpu().numpy()
print(y_pred.shape)
fig = plt.figure()
ax1 = fig.add_subplot(1, 4, 1)
ax1.imshow(X)
ax2 = fig.add_subplot(1, 4, 2)
ax2.imshow(y)
ax3 = fig.add_subplot(1, 4, 3)
ax3.imshow((y_pred > 0.5))
Q = dense_crf(((X * 255).round()).astype(np.uint8), y_pred)
ax4 = fig.add_subplot(1, 4, 4)
print(Q)
ax4.imshow(Q > 0.5)
plt.show()
return tot / i return tot / i

View file

@ -1,48 +1,64 @@
import argparse import argparse
import os
import numpy import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.autograd import Variable
from PIL import Image
from unet import UNet from unet import UNet
from utils import * from utils import resize_and_crop, normalize, split_img_into_squares, hwc_to_chw, merge_masks, dense_crf
from utils import plot_img_and_mask
def predict_img(net,
full_img,
scale_factor=0.5,
out_threshold=0.5,
use_dense_crf=True,
use_gpu=False):
img_height = full_img.size[1]
img_width = full_img.size[0]
img = resize_and_crop(full_img, scale=scale_factor)
img = normalize(img)
left_square, right_square = split_img_into_squares(img)
left_square = hwc_to_chw(left_square)
right_square = hwc_to_chw(right_square)
X_left = torch.from_numpy(left_square).unsqueeze(0)
X_right = torch.from_numpy(right_square).unsqueeze(0)
if use_gpu:
X_left = X_left.cuda()
X_right = X_right.cuda()
with torch.no_grad():
output_left = net(X_left)
output_right = net(X_right)
left_probs = F.sigmoid(output_left)
right_probs = F.sigmoid(output_right)
left_probs = F.upsample(left_probs, size=(img_height, img_height))
right_probs = F.upsample(right_probs, size=(img_height, img_height))
left_mask_np = left_probs.squeeze().cpu().numpy()
right_mask_np = right_probs.squeeze().cpu().numpy()
full_mask = merge_masks(left_mask_np, right_mask_np, img_width)
if use_dense_crf:
full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask)
return full_mask > out_threshold
def predict_img(net, full_img, gpu=False):
img = resize_and_crop(full_img)
left = get_square(img, 0) def get_args():
right = get_square(img, 1)
right = normalize(right)
left = normalize(left)
right = np.transpose(right, axes=[2, 0, 1])
left = np.transpose(left, axes=[2, 0, 1])
X_l = torch.FloatTensor(left).unsqueeze(0)
X_r = torch.FloatTensor(right).unsqueeze(0)
if gpu:
X_l = Variable(X_l, volatile=True).cuda()
X_r = Variable(X_r, volatile=True).cuda()
else:
X_l = Variable(X_l, volatile=True)
X_r = Variable(X_r, volatile=True)
y_l = F.sigmoid(net(X_l))
y_r = F.sigmoid(net(X_r))
y_l = F.upsample_bilinear(y_l, scale_factor=2).data[0][0].cpu().numpy()
y_r = F.upsample_bilinear(y_r, scale_factor=2).data[0][0].cpu().numpy()
y = merge_masks(y_l, y_r, full_img.size[0])
yy = dense_crf(np.array(full_img).astype(np.uint8), y)
return yy > 0.5
if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--model', '-m', default='MODEL.pth', parser.add_argument('--model', '-m', default='MODEL.pth',
metavar='FILE', metavar='FILE',
@ -61,19 +77,22 @@ if __name__ == "__main__":
parser.add_argument('--no-save', '-n', action='store_false', parser.add_argument('--no-save', '-n', action='store_false',
help="Do not save the output masks", help="Do not save the output masks",
default=False) default=False)
parser.add_argument('--no-crf', '-r', action='store_false',
help="Do not use dense CRF postprocessing",
default=False)
parser.add_argument('--mask-threshold', '-t', type=float,
help="Minimum probability value to consider a mask pixel white",
default=0.5)
parser.add_argument('--scale', '-s', type=float,
help="Scale factor for the input images",
default=0.5)
args = parser.parse_args() return parser.parse_args()
print("Using model file : {}".format(args.model))
net = UNet(3, 1)
if not args.cpu:
print("Using CUDA version of the net, prepare your GPU !")
net.cuda()
else:
net.cpu()
print("Using CPU version of the net, this may be very slow")
def get_output_filenames(args):
in_files = args.input in_files = args.input
out_files = [] out_files = []
if not args.output: if not args.output:
for f in in_files: for f in in_files:
pathsplit = os.path.splitext(f) pathsplit = os.path.splitext(f)
@ -84,32 +103,52 @@ if __name__ == "__main__":
else: else:
out_files = args.output out_files = args.output
print("Loading model ...") return out_files
net.load_state_dict(torch.load(args.model))
def mask_to_image(mask):
return Image.fromarray((mask * 255).astype(np.uint8))
if __name__ == "__main__":
args = get_args()
in_files = args.input
out_files = get_output_filenames(args)
net = UNet(n_channels=3, n_classes=1)
print("Loading model {}".format(args.model))
if not args.cpu:
print("Using CUDA version of the net, prepare your GPU !")
net.cuda()
net.load_state_dict(torch.load(args.model))
else:
net.cpu()
net.load_state_dict(torch.load(args.model, map_location='cpu'))
print("Using CPU version of the net, this may be very slow")
print("Model loaded !") print("Model loaded !")
for i, fn in enumerate(in_files): for i, fn in enumerate(in_files):
print("\nPredicting image {} ...".format(fn)) print("\nPredicting image {} ...".format(fn))
img = Image.open(fn) img = Image.open(fn)
out = predict_img(net, img, not args.cpu) if img.size[0] < img.size[1]:
print("Error: image height larger than the width")
mask = predict_img(net=net,
full_img=img,
scale_factor=args.scale,
out_threshold=args.mask_threshold,
use_dense_crf= not args.no_crf,
use_gpu=not args.cpu)
if args.viz: if args.viz:
print("Vizualising results for image {}, close to continue ..." print("Visualizing results for image {}, close to continue ...".format(fn))
.format(fn)) plot_img_and_mask(img, mask)
fig = plt.figure()
a = fig.add_subplot(1, 2, 1)
a.set_title('Input image')
plt.imshow(img)
b = fig.add_subplot(1, 2, 2)
b.set_title('Output mask')
plt.imshow(out)
plt.show()
if not args.no_save: if not args.no_save:
out_fn = out_files[i] out_fn = out_files[i]
result = Image.fromarray((out * 255).astype(numpy.uint8)) result = mask_to_image(mask)
result.save(out_files[i]) result.save(out_files[i])
print("Mask saved to {}".format(out_files[i])) print("Mask saved to {}".format(out_files[i]))

View file

@ -1,10 +1,15 @@
# used to predict all test images and encode results in a csv file import os
from PIL import Image
from predict import * import torch
from predict import predict_img
from utils import rle_encode
from unet import UNet from unet import UNet
def submit(net, gpu=False): def submit(net, gpu=False):
"""Used for Kaggle submission: predicts and encode all test images"""
dir = 'data/test/' dir = 'data/test/'
N = len(list(os.listdir(dir))) N = len(list(os.listdir(dir)))

120
train.py
View file

@ -1,20 +1,27 @@
import sys import sys
import os
from optparse import OptionParser from optparse import OptionParser
import numpy as np
import torch import torch
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import optim from torch import optim
from torch.autograd import Variable
from eval import eval_net from eval import eval_net
from unet import UNet from unet import UNet
from utils import * from utils import get_ids, split_ids, split_train_val, get_imgs_and_masks, batch
def train_net(net,
epochs=5,
batch_size=1,
lr=0.1,
val_percent=0.05,
save_cp=True,
gpu=False,
img_scale=0.5):
def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
cp=True, gpu=False):
dir_img = 'data/train/' dir_img = 'data/train/'
dir_mask = 'data/train_masks/' dir_mask = 'data/train_masks/'
dir_checkpoint = 'checkpoints/' dir_checkpoint = 'checkpoints/'
@ -34,69 +41,66 @@ def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
Checkpoints: {} Checkpoints: {}
CUDA: {} CUDA: {}
'''.format(epochs, batch_size, lr, len(iddataset['train']), '''.format(epochs, batch_size, lr, len(iddataset['train']),
len(iddataset['val']), str(cp), str(gpu))) len(iddataset['val']), str(save_cp), str(gpu)))
N_train = len(iddataset['train']) N_train = len(iddataset['train'])
optimizer = optim.SGD(net.parameters(), optimizer = optim.SGD(net.parameters(),
lr=lr, momentum=0.9, weight_decay=0.0005) lr=lr,
momentum=0.9,
weight_decay=0.0005)
criterion = nn.BCELoss() criterion = nn.BCELoss()
for epoch in range(epochs): for epoch in range(epochs):
print('Starting epoch {}/{}.'.format(epoch + 1, epochs)) print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
# reset the generators # reset the generators
train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask) train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask, img_scale)
val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask) val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask, img_scale)
epoch_loss = 0 epoch_loss = 0
for i, b in enumerate(batch(train, batch_size)):
imgs = np.array([i[0] for i in b]).astype(np.float32)
true_masks = np.array([i[1] for i in b])
imgs = torch.from_numpy(imgs)
true_masks = torch.from_numpy(true_masks)
if gpu:
imgs = imgs.cuda()
true_masks = true_masks.cuda()
masks_pred = net(imgs)
masks_probs = F.sigmoid(masks_pred)
masks_probs_flat = masks_probs.view(-1)
true_masks_flat = true_masks.view(-1)
loss = criterion(masks_probs_flat, true_masks_flat)
epoch_loss += loss.item()
print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train, loss.item()))
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Epoch finished ! Loss: {}'.format(epoch_loss / i))
if 1: if 1:
val_dice = eval_net(net, val, gpu) val_dice = eval_net(net, val, gpu)
print('Validation Dice Coeff: {}'.format(val_dice)) print('Validation Dice Coeff: {}'.format(val_dice))
for i, b in enumerate(batch(train, batch_size)): if save_cp:
X = np.array([i[0] for i in b])
y = np.array([i[1] for i in b])
X = torch.FloatTensor(X)
y = torch.ByteTensor(y)
if gpu:
X = Variable(X).cuda()
y = Variable(y).cuda()
else:
X = Variable(X)
y = Variable(y)
y_pred = net(X)
probs = F.sigmoid(y_pred)
probs_flat = probs.view(-1)
y_flat = y.view(-1)
loss = criterion(probs_flat, y_flat.float())
epoch_loss += loss.data[0]
print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train,
loss.data[0]))
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Epoch finished ! Loss: {}'.format(epoch_loss / i))
if cp:
torch.save(net.state_dict(), torch.save(net.state_dict(),
dir_checkpoint + 'CP{}.pth'.format(epoch + 1)) dir_checkpoint + 'CP{}.pth'.format(epoch + 1))
print('Checkpoint {} saved !'.format(epoch + 1)) print('Checkpoint {} saved !'.format(epoch + 1))
if __name__ == '__main__':
def get_args():
parser = OptionParser() parser = OptionParser()
parser.add_option('-e', '--epochs', dest='epochs', default=5, type='int', parser.add_option('-e', '--epochs', dest='epochs', default=5, type='int',
help='number of epochs') help='number of epochs')
@ -108,22 +112,32 @@ if __name__ == '__main__':
default=False, help='use cuda') default=False, help='use cuda')
parser.add_option('-c', '--load', dest='load', parser.add_option('-c', '--load', dest='load',
default=False, help='load file model') default=False, help='load file model')
parser.add_option('-s', '--scale', dest='scale', type='float',
default=0.5, help='downscaling factor of the images')
(options, args) = parser.parse_args() (options, args) = parser.parse_args()
return options
net = UNet(3, 1) if __name__ == '__main__':
args = get_args()
if options.load: net = UNet(n_channels=3, n_classes=1)
net.load_state_dict(torch.load(options.load))
print('Model loaded from {}'.format(options.load))
if options.gpu: if args.load:
net.load_state_dict(torch.load(args.load))
print('Model loaded from {}'.format(args.load))
if args.gpu:
net.cuda() net.cuda()
cudnn.benchmark = True # cudnn.benchmark = True # faster convolutions, but more memory
try: try:
train_net(net, options.epochs, options.batchsize, options.lr, train_net(net=net,
gpu=options.gpu) epochs=args.epochs,
batch_size=args.batchsize,
lr=args.lr,
gpu=args.gpu,
img_scale=args.scale)
except KeyboardInterrupt: except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth') torch.save(net.state_dict(), 'INTERRUPTED.pth')
print('Saved interrupt') print('Saved interrupt')

View file

@ -1,14 +1,7 @@
#!/usr/bin/python
# full assembly of the sub-parts to form the complete net # full assembly of the sub-parts to form the complete net
import torch
import torch.nn as nn
import torch.nn.functional as F
# python 3 confusing imports :(
from .unet_parts import * from .unet_parts import *
class UNet(nn.Module): class UNet(nn.Module):
def __init__(self, n_channels, n_classes): def __init__(self, n_channels, n_classes):
super(UNet, self).__init__() super(UNet, self).__init__()

View file

@ -1,5 +1,3 @@
#!/usr/bin/python
# sub-parts of the U-Net model # sub-parts of the U-Net model
import torch import torch
@ -53,9 +51,9 @@ class up(nn.Module):
super(up, self).__init__() super(up, self).__init__()
# would be a nice idea if the upsampling could be learned too, # would be a nice idea if the upsampling could be learned too,
#  but my machine do not have enough memory to handle all those weights # but my machine do not have enough memory to handle all those weights
if bilinear: if bilinear:
self.up = nn.UpsamplingBilinear2d(scale_factor=2) self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
else: else:
self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2) self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)

View file

@ -1,7 +1,6 @@
import numpy as np import numpy as np
import pydensecrf.densecrf as dcrf import pydensecrf.densecrf as dcrf
def dense_crf(img, output_probs): def dense_crf(img, output_probs):
h = output_probs.shape[0] h = output_probs.shape[0]
w = output_probs.shape[1] w = output_probs.shape[1]

View file

@ -1,13 +1,12 @@
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
def plot_img_and_mask(img, mask):
def plot_img_mask(img, mask):
fig = plt.figure() fig = plt.figure()
a = fig.add_subplot(1, 2, 1)
a.set_title('Input image')
plt.imshow(img)
ax1 = fig.add_subplot(1, 3, 1) b = fig.add_subplot(1, 2, 2)
ax1.imshow(img) b.set_title('Output mask')
plt.imshow(mask)
ax2 = fig.add_subplot(1, 3, 2) plt.show()
ax2.imshow(mask)
plt.show()

View file

@ -3,12 +3,11 @@
# cropped images and masks # cropped images and masks
import os import os
from functools import partial
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from .utils import resize_and_crop, get_square, normalize from .utils import resize_and_crop, get_square, normalize, hwc_to_chw
def get_ids(dir): def get_ids(dir):
@ -21,23 +20,22 @@ def split_ids(ids, n=2):
return ((id, i) for i in range(n) for id in ids) return ((id, i) for i in range(n) for id in ids)
def to_cropped_imgs(ids, dir, suffix): def to_cropped_imgs(ids, dir, suffix, scale):
"""From a list of tuples, returns the correct cropped img""" """From a list of tuples, returns the correct cropped img"""
for id, pos in ids: for id, pos in ids:
im = resize_and_crop(Image.open(dir + id + suffix)) im = resize_and_crop(Image.open(dir + id + suffix), scale=scale)
yield get_square(im, pos) yield get_square(im, pos)
def get_imgs_and_masks(ids, dir_img, dir_mask, scale):
def get_imgs_and_masks(ids, dir_img, dir_mask):
"""Return all the couples (img, mask)""" """Return all the couples (img, mask)"""
imgs = to_cropped_imgs(ids, dir_img, '.jpg') imgs = to_cropped_imgs(ids, dir_img, '.jpg', scale)
# need to transform from HWC to CHW # need to transform from HWC to CHW
imgs_switched = map(partial(np.transpose, axes=[2, 0, 1]), imgs) imgs_switched = map(hwc_to_chw, imgs)
imgs_normalized = map(normalize, imgs_switched) imgs_normalized = map(normalize, imgs_switched)
masks = to_cropped_imgs(ids, dir_mask, '_mask.gif') masks = to_cropped_imgs(ids, dir_mask, '_mask.gif', scale)
return zip(imgs_normalized, masks) return zip(imgs_normalized, masks)

View file

@ -1,17 +1,20 @@
import random import random
import numpy as np import numpy as np
def get_square(img, pos): def get_square(img, pos):
"""Extract a left or a right square from PILimg shape : (H, W, C))""" """Extract a left or a right square from ndarray shape : (H, W, C))"""
img = np.array(img)
h = img.shape[0] h = img.shape[0]
if pos == 0: if pos == 0:
return img[:, :h] return img[:, :h]
else: else:
return img[:, -h:] return img[:, -h:]
def split_img_into_squares(img):
return get_square(img, 0), get_square(img, 1)
def hwc_to_chw(img):
return np.transpose(img, axes=[2, 0, 1])
def resize_and_crop(pilimg, scale=0.5, final_height=None): def resize_and_crop(pilimg, scale=0.5, final_height=None):
w = pilimg.size[0] w = pilimg.size[0]
@ -26,8 +29,7 @@ def resize_and_crop(pilimg, scale=0.5, final_height=None):
img = pilimg.resize((newW, newH)) img = pilimg.resize((newW, newH))
img = img.crop((0, diff // 2, newW, newH - diff // 2)) img = img.crop((0, diff // 2, newW, newH - diff // 2))
return img return np.array(img, dtype=np.float32)
def batch(iterable, batch_size): def batch(iterable, batch_size):
"""Yields lists by batch""" """Yields lists by batch"""
@ -41,7 +43,6 @@ def batch(iterable, batch_size):
if len(b) > 0: if len(b) > 0:
yield b yield b
def split_train_val(dataset, val_percent=0.05): def split_train_val(dataset, val_percent=0.05):
dataset = list(dataset) dataset = list(dataset)
length = len(dataset) length = len(dataset)
@ -53,18 +54,17 @@ def split_train_val(dataset, val_percent=0.05):
def normalize(x): def normalize(x):
return x / 255 return x / 255
def merge_masks(img1, img2, full_w): def merge_masks(img1, img2, full_w):
h = img1.shape[0] h = img1.shape[0]
new = np.zeros((h, full_w), np.float32) new = np.zeros((h, full_w), np.float32)
new[:, :full_w // 2 + 1] = img1[:, :full_w // 2 + 1] new[:, :full_w // 2 + 1] = img1[:, :full_w // 2 + 1]
new[:, full_w // 2 + 1:] = img2[:, -(full_w // 2 - 1):] new[:, full_w // 2 + 1:] = img2[:, -(full_w // 2 - 1):]
return new return new
# credits to https://stackoverflow.com/users/6076729/manuel-lagunas
def rle_encode(mask_image): def rle_encode(mask_image):
pixels = mask_image.flatten() pixels = mask_image.flatten()
# We avoid issues with '1' at the start or end (at the corners of # We avoid issues with '1' at the start or end (at the corners of