From 02e231414906f995b88a352f8ec0b0cbbc5199d3 Mon Sep 17 00:00:00 2001 From: milesial Date: Fri, 8 Jun 2018 19:27:32 +0200 Subject: [PATCH] Migration to PyTorch 0.4, code cleanup Former-commit-id: c981801ccc3b74047e94c76e67c4ff1f3097226c --- myloss.py => dice_loss.py | 11 +-- eval.py | 52 +++--------- predict.py | 163 +++++++++++++++++++++++--------------- submit.py | 9 ++- train.py | 120 +++++++++++++++------------- unet/unet_model.py | 7 -- unet/unet_parts.py | 6 +- utils/crf.py | 1 - utils/data_vis.py | 17 ++-- utils/load.py | 16 ++-- utils/utils.py | 16 ++-- 11 files changed, 214 insertions(+), 204 deletions(-) rename myloss.py => dice_loss.py (81%) diff --git a/myloss.py b/dice_loss.py similarity index 81% rename from myloss.py rename to dice_loss.py index e28a10d..29a287d 100644 --- a/myloss.py +++ b/dice_loss.py @@ -1,17 +1,12 @@ -# -# myloss.py : implementation of the Dice coeff and the associated loss -# - import torch from torch.autograd import Function, Variable - class DiceCoeff(Function): """Dice coeff for individual examples""" def forward(self, input, target): self.save_for_backward(input, target) - self.inter = torch.dot(input, target) + 0.0001 + self.inter = torch.dot(input.view(-1), target.view(-1)) + 0.0001 self.union = torch.sum(input) + torch.sum(target) + 0.0001 t = 2 * self.inter.float() / self.union.float() @@ -35,9 +30,9 @@ class DiceCoeff(Function): def dice_coeff(input, target): """Dice coeff for batches""" if input.is_cuda: - s = Variable(torch.FloatTensor(1).cuda().zero_()) + s = torch.FloatTensor(1).cuda().zero_() else: - s = Variable(torch.FloatTensor(1).zero_()) + s = torch.FloatTensor(1).zero_() for i, c in enumerate(zip(input, target)): s = s + DiceCoeff().forward(c[0], c[1]) diff --git a/eval.py b/eval.py index de56801..944c111 100644 --- a/eval.py +++ b/eval.py @@ -1,55 +1,25 @@ -import matplotlib.pyplot as plt -import numpy as np import torch import torch.nn.functional as F -from torch.autograd import Variable -from myloss import dice_coeff -from utils import dense_crf +from dice_loss import dice_coeff def eval_net(net, dataset, gpu=False): + """Evaluation without the densecrf with the dice coefficient""" tot = 0 for i, b in enumerate(dataset): - X = b[0] - y = b[1] + img = b[0] + true_mask = b[1] - X = torch.FloatTensor(X).unsqueeze(0) - y = torch.ByteTensor(y).unsqueeze(0) + img = torch.from_numpy(img).unsqueeze(0) + true_mask = torch.from_numpy(true_mask).unsqueeze(0) if gpu: - X = Variable(X, volatile=True).cuda() - y = Variable(y, volatile=True).cuda() - else: - X = Variable(X, volatile=True) - y = Variable(y, volatile=True) + img = img.cuda() + true_mask = true_mask.cuda() - y_pred = net(X) + mask_pred = net(img)[0] + mask_pred = (F.sigmoid(mask_pred) > 0.5).float() - y_pred = (F.sigmoid(y_pred) > 0.6).float() - # y_pred = F.sigmoid(y_pred).float() - - dice = dice_coeff(y_pred, y.float()).data[0] - tot += dice - - if 0: - X = X.data.squeeze(0).cpu().numpy() - X = np.transpose(X, axes=[1, 2, 0]) - y = y.data.squeeze(0).cpu().numpy() - y_pred = y_pred.data.squeeze(0).squeeze(0).cpu().numpy() - print(y_pred.shape) - - fig = plt.figure() - ax1 = fig.add_subplot(1, 4, 1) - ax1.imshow(X) - ax2 = fig.add_subplot(1, 4, 2) - ax2.imshow(y) - ax3 = fig.add_subplot(1, 4, 3) - ax3.imshow((y_pred > 0.5)) - - Q = dense_crf(((X * 255).round()).astype(np.uint8), y_pred) - ax4 = fig.add_subplot(1, 4, 4) - print(Q) - ax4.imshow(Q > 0.5) - plt.show() + tot += dice_coeff(mask_pred, true_mask).item() return tot / i diff --git a/predict.py b/predict.py index 45d442e..d963b04 100644 --- a/predict.py +++ b/predict.py @@ -1,48 +1,64 @@ import argparse +import os -import numpy +import numpy as np import torch import torch.nn.functional as F -from torch.autograd import Variable + +from PIL import Image from unet import UNet -from utils import * +from utils import resize_and_crop, normalize, split_img_into_squares, hwc_to_chw, merge_masks, dense_crf +from utils import plot_img_and_mask + +def predict_img(net, + full_img, + scale_factor=0.5, + out_threshold=0.5, + use_dense_crf=True, + use_gpu=False): + + img_height = full_img.size[1] + img_width = full_img.size[0] + + img = resize_and_crop(full_img, scale=scale_factor) + img = normalize(img) + + left_square, right_square = split_img_into_squares(img) + + left_square = hwc_to_chw(left_square) + right_square = hwc_to_chw(right_square) + + X_left = torch.from_numpy(left_square).unsqueeze(0) + X_right = torch.from_numpy(right_square).unsqueeze(0) + + if use_gpu: + X_left = X_left.cuda() + X_right = X_right.cuda() + + with torch.no_grad(): + output_left = net(X_left) + output_right = net(X_right) + + left_probs = F.sigmoid(output_left) + right_probs = F.sigmoid(output_right) + + left_probs = F.upsample(left_probs, size=(img_height, img_height)) + right_probs = F.upsample(right_probs, size=(img_height, img_height)) + + left_mask_np = left_probs.squeeze().cpu().numpy() + right_mask_np = right_probs.squeeze().cpu().numpy() + + full_mask = merge_masks(left_mask_np, right_mask_np, img_width) + + if use_dense_crf: + full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask) + + return full_mask > out_threshold -def predict_img(net, full_img, gpu=False): - img = resize_and_crop(full_img) - left = get_square(img, 0) - right = get_square(img, 1) - - right = normalize(right) - left = normalize(left) - - right = np.transpose(right, axes=[2, 0, 1]) - left = np.transpose(left, axes=[2, 0, 1]) - - X_l = torch.FloatTensor(left).unsqueeze(0) - X_r = torch.FloatTensor(right).unsqueeze(0) - - if gpu: - X_l = Variable(X_l, volatile=True).cuda() - X_r = Variable(X_r, volatile=True).cuda() - else: - X_l = Variable(X_l, volatile=True) - X_r = Variable(X_r, volatile=True) - - y_l = F.sigmoid(net(X_l)) - y_r = F.sigmoid(net(X_r)) - y_l = F.upsample_bilinear(y_l, scale_factor=2).data[0][0].cpu().numpy() - y_r = F.upsample_bilinear(y_r, scale_factor=2).data[0][0].cpu().numpy() - - y = merge_masks(y_l, y_r, full_img.size[0]) - yy = dense_crf(np.array(full_img).astype(np.uint8), y) - - return yy > 0.5 - - -if __name__ == "__main__": +def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--model', '-m', default='MODEL.pth', metavar='FILE', @@ -61,19 +77,22 @@ if __name__ == "__main__": parser.add_argument('--no-save', '-n', action='store_false', help="Do not save the output masks", default=False) + parser.add_argument('--no-crf', '-r', action='store_false', + help="Do not use dense CRF postprocessing", + default=False) + parser.add_argument('--mask-threshold', '-t', type=float, + help="Minimum probability value to consider a mask pixel white", + default=0.5) + parser.add_argument('--scale', '-s', type=float, + help="Scale factor for the input images", + default=0.5) - args = parser.parse_args() - print("Using model file : {}".format(args.model)) - net = UNet(3, 1) - if not args.cpu: - print("Using CUDA version of the net, prepare your GPU !") - net.cuda() - else: - net.cpu() - print("Using CPU version of the net, this may be very slow") + return parser.parse_args() +def get_output_filenames(args): in_files = args.input out_files = [] + if not args.output: for f in in_files: pathsplit = os.path.splitext(f) @@ -84,32 +103,52 @@ if __name__ == "__main__": else: out_files = args.output - print("Loading model ...") - net.load_state_dict(torch.load(args.model)) + return out_files + +def mask_to_image(mask): + return Image.fromarray((mask * 255).astype(np.uint8)) + +if __name__ == "__main__": + args = get_args() + in_files = args.input + out_files = get_output_filenames(args) + + net = UNet(n_channels=3, n_classes=1) + + print("Loading model {}".format(args.model)) + + if not args.cpu: + print("Using CUDA version of the net, prepare your GPU !") + net.cuda() + net.load_state_dict(torch.load(args.model)) + else: + net.cpu() + net.load_state_dict(torch.load(args.model, map_location='cpu')) + print("Using CPU version of the net, this may be very slow") + print("Model loaded !") for i, fn in enumerate(in_files): print("\nPredicting image {} ...".format(fn)) + img = Image.open(fn) - out = predict_img(net, img, not args.cpu) + if img.size[0] < img.size[1]: + print("Error: image height larger than the width") + + mask = predict_img(net=net, + full_img=img, + scale_factor=args.scale, + out_threshold=args.mask_threshold, + use_dense_crf= not args.no_crf, + use_gpu=not args.cpu) if args.viz: - print("Vizualising results for image {}, close to continue ..." - .format(fn)) - - fig = plt.figure() - a = fig.add_subplot(1, 2, 1) - a.set_title('Input image') - plt.imshow(img) - - b = fig.add_subplot(1, 2, 2) - b.set_title('Output mask') - plt.imshow(out) - - plt.show() + print("Visualizing results for image {}, close to continue ...".format(fn)) + plot_img_and_mask(img, mask) if not args.no_save: out_fn = out_files[i] - result = Image.fromarray((out * 255).astype(numpy.uint8)) + result = mask_to_image(mask) result.save(out_files[i]) + print("Mask saved to {}".format(out_files[i])) diff --git a/submit.py b/submit.py index 93d197e..12f26cf 100644 --- a/submit.py +++ b/submit.py @@ -1,10 +1,15 @@ -# used to predict all test images and encode results in a csv file +import os +from PIL import Image -from predict import * +import torch + +from predict import predict_img +from utils import rle_encode from unet import UNet def submit(net, gpu=False): + """Used for Kaggle submission: predicts and encode all test images""" dir = 'data/test/' N = len(list(os.listdir(dir))) diff --git a/train.py b/train.py index c332657..3b0e4d6 100644 --- a/train.py +++ b/train.py @@ -1,20 +1,27 @@ import sys +import os from optparse import OptionParser +import numpy as np import torch import torch.backends.cudnn as cudnn import torch.nn as nn import torch.nn.functional as F from torch import optim -from torch.autograd import Variable from eval import eval_net from unet import UNet -from utils import * +from utils import get_ids, split_ids, split_train_val, get_imgs_and_masks, batch +def train_net(net, + epochs=5, + batch_size=1, + lr=0.1, + val_percent=0.05, + save_cp=True, + gpu=False, + img_scale=0.5): -def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05, - cp=True, gpu=False): dir_img = 'data/train/' dir_mask = 'data/train_masks/' dir_checkpoint = 'checkpoints/' @@ -34,69 +41,66 @@ def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05, Checkpoints: {} CUDA: {} '''.format(epochs, batch_size, lr, len(iddataset['train']), - len(iddataset['val']), str(cp), str(gpu))) + len(iddataset['val']), str(save_cp), str(gpu))) N_train = len(iddataset['train']) optimizer = optim.SGD(net.parameters(), - lr=lr, momentum=0.9, weight_decay=0.0005) + lr=lr, + momentum=0.9, + weight_decay=0.0005) + criterion = nn.BCELoss() for epoch in range(epochs): print('Starting epoch {}/{}.'.format(epoch + 1, epochs)) # reset the generators - train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask) - val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask) + train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask, img_scale) + val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask, img_scale) epoch_loss = 0 + for i, b in enumerate(batch(train, batch_size)): + imgs = np.array([i[0] for i in b]).astype(np.float32) + true_masks = np.array([i[1] for i in b]) + + imgs = torch.from_numpy(imgs) + true_masks = torch.from_numpy(true_masks) + + if gpu: + imgs = imgs.cuda() + true_masks = true_masks.cuda() + + masks_pred = net(imgs) + masks_probs = F.sigmoid(masks_pred) + masks_probs_flat = masks_probs.view(-1) + + true_masks_flat = true_masks.view(-1) + + loss = criterion(masks_probs_flat, true_masks_flat) + epoch_loss += loss.item() + + print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train, loss.item())) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + print('Epoch finished ! Loss: {}'.format(epoch_loss / i)) + if 1: val_dice = eval_net(net, val, gpu) print('Validation Dice Coeff: {}'.format(val_dice)) - for i, b in enumerate(batch(train, batch_size)): - X = np.array([i[0] for i in b]) - y = np.array([i[1] for i in b]) - - X = torch.FloatTensor(X) - y = torch.ByteTensor(y) - - if gpu: - X = Variable(X).cuda() - y = Variable(y).cuda() - else: - X = Variable(X) - y = Variable(y) - - y_pred = net(X) - probs = F.sigmoid(y_pred) - probs_flat = probs.view(-1) - - y_flat = y.view(-1) - - loss = criterion(probs_flat, y_flat.float()) - epoch_loss += loss.data[0] - - print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train, - loss.data[0])) - - optimizer.zero_grad() - - loss.backward() - - optimizer.step() - - print('Epoch finished ! Loss: {}'.format(epoch_loss / i)) - - if cp: + if save_cp: torch.save(net.state_dict(), dir_checkpoint + 'CP{}.pth'.format(epoch + 1)) - print('Checkpoint {} saved !'.format(epoch + 1)) -if __name__ == '__main__': + +def get_args(): parser = OptionParser() parser.add_option('-e', '--epochs', dest='epochs', default=5, type='int', help='number of epochs') @@ -108,22 +112,32 @@ if __name__ == '__main__': default=False, help='use cuda') parser.add_option('-c', '--load', dest='load', default=False, help='load file model') + parser.add_option('-s', '--scale', dest='scale', type='float', + default=0.5, help='downscaling factor of the images') (options, args) = parser.parse_args() + return options - net = UNet(3, 1) +if __name__ == '__main__': + args = get_args() - if options.load: - net.load_state_dict(torch.load(options.load)) - print('Model loaded from {}'.format(options.load)) + net = UNet(n_channels=3, n_classes=1) - if options.gpu: + if args.load: + net.load_state_dict(torch.load(args.load)) + print('Model loaded from {}'.format(args.load)) + + if args.gpu: net.cuda() - cudnn.benchmark = True + # cudnn.benchmark = True # faster convolutions, but more memory try: - train_net(net, options.epochs, options.batchsize, options.lr, - gpu=options.gpu) + train_net(net=net, + epochs=args.epochs, + batch_size=args.batchsize, + lr=args.lr, + gpu=args.gpu, + img_scale=args.scale) except KeyboardInterrupt: torch.save(net.state_dict(), 'INTERRUPTED.pth') print('Saved interrupt') diff --git a/unet/unet_model.py b/unet/unet_model.py index 4afb8dd..a09ee5b 100644 --- a/unet/unet_model.py +++ b/unet/unet_model.py @@ -1,14 +1,7 @@ -#!/usr/bin/python # full assembly of the sub-parts to form the complete net -import torch -import torch.nn as nn -import torch.nn.functional as F - -# python 3 confusing imports :( from .unet_parts import * - class UNet(nn.Module): def __init__(self, n_channels, n_classes): super(UNet, self).__init__() diff --git a/unet/unet_parts.py b/unet/unet_parts.py index c7128d0..7fcadc7 100644 --- a/unet/unet_parts.py +++ b/unet/unet_parts.py @@ -1,5 +1,3 @@ -#!/usr/bin/python - # sub-parts of the U-Net model import torch @@ -53,9 +51,9 @@ class up(nn.Module): super(up, self).__init__() # would be a nice idea if the upsampling could be learned too, - #  but my machine do not have enough memory to handle all those weights + # but my machine do not have enough memory to handle all those weights if bilinear: - self.up = nn.UpsamplingBilinear2d(scale_factor=2) + self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) else: self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2) diff --git a/utils/crf.py b/utils/crf.py index 5ee718f..8a79953 100644 --- a/utils/crf.py +++ b/utils/crf.py @@ -1,7 +1,6 @@ import numpy as np import pydensecrf.densecrf as dcrf - def dense_crf(img, output_probs): h = output_probs.shape[0] w = output_probs.shape[1] diff --git a/utils/data_vis.py b/utils/data_vis.py index 365e4d1..4ec2f60 100644 --- a/utils/data_vis.py +++ b/utils/data_vis.py @@ -1,13 +1,12 @@ import matplotlib.pyplot as plt - -def plot_img_mask(img, mask): +def plot_img_and_mask(img, mask): fig = plt.figure() + a = fig.add_subplot(1, 2, 1) + a.set_title('Input image') + plt.imshow(img) - ax1 = fig.add_subplot(1, 3, 1) - ax1.imshow(img) - - ax2 = fig.add_subplot(1, 3, 2) - ax2.imshow(mask) - - plt.show() + b = fig.add_subplot(1, 2, 2) + b.set_title('Output mask') + plt.imshow(mask) + plt.show() \ No newline at end of file diff --git a/utils/load.py b/utils/load.py index 5ab7f80..8317ffc 100644 --- a/utils/load.py +++ b/utils/load.py @@ -3,12 +3,11 @@ # cropped images and masks import os -from functools import partial import numpy as np from PIL import Image -from .utils import resize_and_crop, get_square, normalize +from .utils import resize_and_crop, get_square, normalize, hwc_to_chw def get_ids(dir): @@ -21,23 +20,22 @@ def split_ids(ids, n=2): return ((id, i) for i in range(n) for id in ids) -def to_cropped_imgs(ids, dir, suffix): +def to_cropped_imgs(ids, dir, suffix, scale): """From a list of tuples, returns the correct cropped img""" for id, pos in ids: - im = resize_and_crop(Image.open(dir + id + suffix)) + im = resize_and_crop(Image.open(dir + id + suffix), scale=scale) yield get_square(im, pos) - -def get_imgs_and_masks(ids, dir_img, dir_mask): +def get_imgs_and_masks(ids, dir_img, dir_mask, scale): """Return all the couples (img, mask)""" - imgs = to_cropped_imgs(ids, dir_img, '.jpg') + imgs = to_cropped_imgs(ids, dir_img, '.jpg', scale) # need to transform from HWC to CHW - imgs_switched = map(partial(np.transpose, axes=[2, 0, 1]), imgs) + imgs_switched = map(hwc_to_chw, imgs) imgs_normalized = map(normalize, imgs_switched) - masks = to_cropped_imgs(ids, dir_mask, '_mask.gif') + masks = to_cropped_imgs(ids, dir_mask, '_mask.gif', scale) return zip(imgs_normalized, masks) diff --git a/utils/utils.py b/utils/utils.py index 9b26506..830f1ce 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -1,17 +1,20 @@ import random - import numpy as np def get_square(img, pos): - """Extract a left or a right square from PILimg shape : (H, W, C))""" - img = np.array(img) + """Extract a left or a right square from ndarray shape : (H, W, C))""" h = img.shape[0] if pos == 0: return img[:, :h] else: return img[:, -h:] +def split_img_into_squares(img): + return get_square(img, 0), get_square(img, 1) + +def hwc_to_chw(img): + return np.transpose(img, axes=[2, 0, 1]) def resize_and_crop(pilimg, scale=0.5, final_height=None): w = pilimg.size[0] @@ -26,8 +29,7 @@ def resize_and_crop(pilimg, scale=0.5, final_height=None): img = pilimg.resize((newW, newH)) img = img.crop((0, diff // 2, newW, newH - diff // 2)) - return img - + return np.array(img, dtype=np.float32) def batch(iterable, batch_size): """Yields lists by batch""" @@ -41,7 +43,6 @@ def batch(iterable, batch_size): if len(b) > 0: yield b - def split_train_val(dataset, val_percent=0.05): dataset = list(dataset) length = len(dataset) @@ -53,18 +54,17 @@ def split_train_val(dataset, val_percent=0.05): def normalize(x): return x / 255 - def merge_masks(img1, img2, full_w): h = img1.shape[0] new = np.zeros((h, full_w), np.float32) - new[:, :full_w // 2 + 1] = img1[:, :full_w // 2 + 1] new[:, full_w // 2 + 1:] = img2[:, -(full_w // 2 - 1):] return new +# credits to https://stackoverflow.com/users/6076729/manuel-lagunas def rle_encode(mask_image): pixels = mask_image.flatten() # We avoid issues with '1' at the start or end (at the corners of