From 4063565295aae8f739b006ac5ab814c12bd8a81e Mon Sep 17 00:00:00 2001 From: milesial Date: Thu, 17 Aug 2017 21:16:19 +0200 Subject: [PATCH] Created a basic train loop + changed a bit loss and utils --- .gitignore | 2 + data_vis.py | 1 + load.py | 31 +++++++-------- myloss.py | 44 ++++++++++++++------- train.py | 105 ++++++++++++++++++++++++++++++++++++++++++++++++++ unet_model.py | 1 + unet_parts.py | 9 ++++- utils.py | 42 ++++++++++++++++---- 8 files changed, 195 insertions(+), 40 deletions(-) diff --git a/.gitignore b/.gitignore index a014da0..025cf91 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,6 @@ *.pyc data/ __pycache__/ +checkpoints/ *.pth + diff --git a/data_vis.py b/data_vis.py index 714acc5..365e4d1 100644 --- a/data_vis.py +++ b/data_vis.py @@ -1,5 +1,6 @@ import matplotlib.pyplot as plt + def plot_img_mask(img, mask): fig = plt.figure() diff --git a/load.py b/load.py index a981f03..c4418c1 100644 --- a/load.py +++ b/load.py @@ -1,47 +1,42 @@ + +# +# load.py : utils on generators / lists of ids to transform from strings to +# cropped images and masks + import os -import random import numpy as np + from PIL import Image from functools import partial -from utils import resize_and_crop, get_square +from utils import resize_and_crop, get_square, normalize def get_ids(dir): """Returns a list of the ids in the directory""" return (f[:-4] for f in os.listdir(dir)) + def split_ids(ids, n=2): """Split each id in n, creating n tuples (id, k) for each id""" return ((id, i) for i in range(n) for id in ids) -def shuffle_ids(ids): - """Returns a shuffle list od the ids""" - lst = list(ids) - random.shuffle(lst) - return lst def to_cropped_imgs(ids, dir, suffix): - """From a list of tuples, returns the correct cropped img (left or right)""" + """From a list of tuples, returns the correct cropped img""" for id, pos in ids: im = resize_and_crop(Image.open(dir + id + suffix)) yield get_square(im, pos) - -def get_imgs_and_masks(): - """From the list of ids, return the couples (img, mask)""" - dir_img = 'data/train/' - dir_mask = 'data/train_masks/' - - ids = get_ids(dir_img) - ids = split_ids(ids) - ids = shuffle_ids(ids) +def get_imgs_and_masks(ids, dir_img, dir_mask): + """Return all the couples (img, mask)""" imgs = to_cropped_imgs(ids, dir_img, '.jpg') # need to transform from HWC to CHW imgs_switched = map(partial(np.transpose, axes=[2, 0, 1]), imgs) + imgs_normalized = map(normalize, imgs_switched) masks = to_cropped_imgs(ids, dir_mask, '_mask.gif') - return zip(imgs_switched, masks) + return zip(imgs_normalized, masks) diff --git a/myloss.py b/myloss.py index e65f0f1..2c26c39 100644 --- a/myloss.py +++ b/myloss.py @@ -1,34 +1,52 @@ + +# +# myloss.py : implementation of the Dice coeff and the associated loss +# + import torch -from torch.nn.modules.loss import _Loss -from torch.autograd import Function import torch.nn.functional as F +from torch.nn.modules.loss import _Loss +from torch.autograd import Function, Variable + + class DiceCoeff(Function): + """Dice coeff for individual examples""" + def forward(self, input, target): + self.save_for_backward(input, target) + self.inter = torch.dot(input, target) + 0.0001 + self.union = torch.sum(input) + torch.sum(target) + 0.0001 - def forward(ctx, input, target): - ctx.save_for_backward(input, target) - ctx.inter = torch.dot(input, target) + 0.0001 - ctx.union = torch.sum(input) + torch.sum(target) + 0.0001 - - t = 2*ctx.inter.float()/ctx.union.float() + t = 2*self.inter.float()/self.union.float() return t # This function has only a single output, so it gets only one gradient - def backward(ctx, grad_output): + def backward(self, grad_output): - input, target = ctx.saved_variables + input, target = self.saved_variables grad_input = grad_target = None if self.needs_input_grad[0]: - grad_input = grad_output * 2 * (target * ctx.union + ctx.inter) \ - / ctx.union * ctx.union + grad_input = grad_output * 2 * (target * self.union + self.inter) \ + / self.union * self.union if self.needs_input_grad[1]: grad_target = None return grad_input, grad_target + def dice_coeff(input, target): - return DiceCoeff().forward(input, target) + """Dice coeff for batches""" + if input.is_cuda: + s = Variable(torch.FloatTensor(1).cuda().zero_()) + else: + s = Variable(torch.FloatTensor(1).zero_()) + + for i, c in enumerate(zip(input, target)): + s = s + DiceCoeff().forward(c[0], c[1]) + + return s / (i+1) + class DiceLoss(_Loss): def forward(self, input, target): diff --git a/train.py b/train.py index e69de29..2dedee2 100644 --- a/train.py +++ b/train.py @@ -0,0 +1,105 @@ +import torch + +from load import * +from data_vis import * +from utils import split_train_val, batch +from myloss import DiceLoss +from unet_model import UNet +from torch.autograd import Variable +from torch import optim +from optparse import OptionParser + + +def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05, + cp=True, gpu=False): + dir_img = 'data/train/' + dir_mask = 'data/train_masks/' + dir_checkpoint = 'checkpoints/' + + # get ids + ids = get_ids(dir_img) + ids = split_ids(ids) + + iddataset = split_train_val(ids, val_percent) + + print(''' + Starting training: + Epochs: {} + Batch size: {} + Learning rate: {} + Training size: {} + Validation size: {} + Checkpoints: {} + CUDA: {} + '''.format(epochs, batch_size, lr, len(iddataset['train']), + len(iddataset['val']), str(cp), str(gpu))) + + N_train = len(iddataset['train']) + + train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask) + val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask) + + optimizer = optim.Adam(net.parameters(), lr=lr) + criterion = DiceLoss() + + for epoch in range(epochs): + print('Starting epoch {}/{}.'.format(epoch+1, epochs)) + + epoch_loss = 0 + + for i, b in enumerate(batch(train, batch_size)): + X = np.array([i[0] for i in b]) + y = np.array([i[1] for i in b]) + + X = torch.FloatTensor(X) + y = torch.ByteTensor(y) + + if gpu: + X = Variable(X).cuda() + y = Variable(y).cuda() + else: + X = Variable(X) + y = Variable(y) + + optimizer.zero_grad() + + y_pred = net(X) + + loss = criterion(y_pred, y.float()) + epoch_loss += loss.data[0] + + print('{0:.4f} --- loss: {1:.6f}'.format(i*batch_size/N_train, + loss.data[0])) + + loss.backward() + optimizer.step() + + print('Epoch finished ! Loss: {}'.format(epoch_loss/i)) + + if cp: + torch.save(net.state_dict(), + dir_checkpoint + 'CP{}.pth'.format(epoch+1)) + + print('Checkpoint {} saved !'.format(epoch+1)) + + +parser = OptionParser() +parser.add_option("-e", "--epochs", dest="epochs", default=5, type="int", + help="number of epochs") +parser.add_option("-b", "--batch-size", dest="batchsize", default=10, + type="int", help="batch size") +parser.add_option("-l", "--learning-rate", dest="lr", default=0.1, + type="int", help="learning rate") +parser.add_option("-g", "--gpu", action="store_true", dest="gpu", + default=False, help="use cuda") +parser.add_option("-n", "--ngpu", action="store_false", dest="gpu", + default=False, help="use cuda") + + +(options, args) = parser.parse_args() + +net = UNet(3, 1) +if options.gpu: + net.cuda() + +train_net(net, options.epochs, options.batchsize, options.lr, gpu=options.gpu) diff --git a/unet_model.py b/unet_model.py index 6d5dc39..5129e64 100644 --- a/unet_model.py +++ b/unet_model.py @@ -4,6 +4,7 @@ import torch.nn.functional as F from unet_parts import * + class UNet(nn.Module): def __init__(self, n_channels, n_classes): super(UNet, self).__init__() diff --git a/unet_parts.py b/unet_parts.py index 37ec10e..08fd8d0 100644 --- a/unet_parts.py +++ b/unet_parts.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn import torch.nn.functional as F + class double_conv(nn.Module): def __init__(self, in_ch, out_ch): super(double_conv, self).__init__() @@ -13,10 +14,12 @@ class double_conv(nn.Module): nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU() ) + def forward(self, x): x = self.conv(x) return x + class inconv(nn.Module): def __init__(self, in_ch, out_ch): super(inconv, self).__init__() @@ -26,6 +29,7 @@ class inconv(nn.Module): x = self.conv(x) return x + class down(nn.Module): def __init__(self, in_ch, out_ch): super(down, self).__init__() @@ -38,15 +42,15 @@ class down(nn.Module): x = self.mpconv(x) return x + class up(nn.Module): def __init__(self, in_ch, out_ch): super(up, self).__init__() self.up = nn.UpsamplingBilinear2d(scale_factor=2) - #self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2) + # self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2) self.conv = double_conv(in_ch, out_ch) def forward(self, x1, x2): - x1 = self.up(x1) diffX = x1.size()[2] - x2.size()[2] diffY = x1.size()[3] - x2.size()[3] @@ -56,6 +60,7 @@ class up(nn.Module): x = self.conv(x) return x + class outconv(nn.Module): def __init__(self, in_ch, out_ch): super(outconv, self).__init__() diff --git a/utils.py b/utils.py index 9732526..b51cd74 100644 --- a/utils.py +++ b/utils.py @@ -1,26 +1,54 @@ import PIL import numpy as np +import random + def get_square(img, pos): - """Extract a left or a right square from PILimg""" - """shape : (H, W, C))""" + """Extract a left or a right square from PILimg shape : (H, W, C))""" img = np.array(img) - h = img.shape[0] - w = img.shape[1] - if pos == 0: return img[:, :h] else: return img[:, -h:] -def resize_and_crop(pilimg, scale=0.5, final_height=640): + +def resize_and_crop(pilimg, scale=0.2, final_height=None): w = pilimg.size[0] h = pilimg.size[1] newW = int(w * scale) newH = int(h * scale) - diff = newH - final_height + + if not final_height: + diff = 0 + else: + diff = newH - final_height img = pilimg.resize((newW, newH)) img = img.crop((0, diff // 2, newW, newH - diff // 2)) return img + + +def batch(iterable, batch_size): + """Yields lists by batch""" + b = [] + for i, t in enumerate(iterable): + b.append(t) + if (i+1) % batch_size == 0: + yield b + b = [] + + if len(b) > 0: + yield b + + +def split_train_val(dataset, val_percent=0.05): + dataset = list(dataset) + length = len(dataset) + n = int(length * val_percent) + random.shuffle(dataset) + return {'train': dataset[:-n], 'val': dataset[-n:]} + + +def normalize(x): + return x / 255