Created a basic train loop + changed a bit loss and utils

This commit is contained in:
milesial 2017-08-17 21:16:19 +02:00
parent 8332f891c3
commit 4063565295
8 changed files with 195 additions and 40 deletions

2
.gitignore vendored
View file

@ -1,4 +1,6 @@
*.pyc *.pyc
data/ data/
__pycache__/ __pycache__/
checkpoints/
*.pth *.pth

View file

@ -1,5 +1,6 @@
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
def plot_img_mask(img, mask): def plot_img_mask(img, mask):
fig = plt.figure() fig = plt.figure()

31
load.py
View file

@ -1,47 +1,42 @@
#
# load.py : utils on generators / lists of ids to transform from strings to
# cropped images and masks
import os import os
import random
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from functools import partial from functools import partial
from utils import resize_and_crop, get_square from utils import resize_and_crop, get_square, normalize
def get_ids(dir): def get_ids(dir):
"""Returns a list of the ids in the directory""" """Returns a list of the ids in the directory"""
return (f[:-4] for f in os.listdir(dir)) return (f[:-4] for f in os.listdir(dir))
def split_ids(ids, n=2): def split_ids(ids, n=2):
"""Split each id in n, creating n tuples (id, k) for each id""" """Split each id in n, creating n tuples (id, k) for each id"""
return ((id, i) for i in range(n) for id in ids) return ((id, i) for i in range(n) for id in ids)
def shuffle_ids(ids):
"""Returns a shuffle list od the ids"""
lst = list(ids)
random.shuffle(lst)
return lst
def to_cropped_imgs(ids, dir, suffix): def to_cropped_imgs(ids, dir, suffix):
"""From a list of tuples, returns the correct cropped img (left or right)""" """From a list of tuples, returns the correct cropped img"""
for id, pos in ids: for id, pos in ids:
im = resize_and_crop(Image.open(dir + id + suffix)) im = resize_and_crop(Image.open(dir + id + suffix))
yield get_square(im, pos) yield get_square(im, pos)
def get_imgs_and_masks(ids, dir_img, dir_mask):
def get_imgs_and_masks(): """Return all the couples (img, mask)"""
"""From the list of ids, return the couples (img, mask)"""
dir_img = 'data/train/'
dir_mask = 'data/train_masks/'
ids = get_ids(dir_img)
ids = split_ids(ids)
ids = shuffle_ids(ids)
imgs = to_cropped_imgs(ids, dir_img, '.jpg') imgs = to_cropped_imgs(ids, dir_img, '.jpg')
# need to transform from HWC to CHW # need to transform from HWC to CHW
imgs_switched = map(partial(np.transpose, axes=[2, 0, 1]), imgs) imgs_switched = map(partial(np.transpose, axes=[2, 0, 1]), imgs)
imgs_normalized = map(normalize, imgs_switched)
masks = to_cropped_imgs(ids, dir_mask, '_mask.gif') masks = to_cropped_imgs(ids, dir_mask, '_mask.gif')
return zip(imgs_switched, masks) return zip(imgs_normalized, masks)

View file

@ -1,34 +1,52 @@
#
# myloss.py : implementation of the Dice coeff and the associated loss
#
import torch import torch
from torch.nn.modules.loss import _Loss
from torch.autograd import Function
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from torch.autograd import Function, Variable
class DiceCoeff(Function): class DiceCoeff(Function):
"""Dice coeff for individual examples"""
def forward(self, input, target):
self.save_for_backward(input, target)
self.inter = torch.dot(input, target) + 0.0001
self.union = torch.sum(input) + torch.sum(target) + 0.0001
def forward(ctx, input, target): t = 2*self.inter.float()/self.union.float()
ctx.save_for_backward(input, target)
ctx.inter = torch.dot(input, target) + 0.0001
ctx.union = torch.sum(input) + torch.sum(target) + 0.0001
t = 2*ctx.inter.float()/ctx.union.float()
return t return t
# This function has only a single output, so it gets only one gradient # This function has only a single output, so it gets only one gradient
def backward(ctx, grad_output): def backward(self, grad_output):
input, target = ctx.saved_variables input, target = self.saved_variables
grad_input = grad_target = None grad_input = grad_target = None
if self.needs_input_grad[0]: if self.needs_input_grad[0]:
grad_input = grad_output * 2 * (target * ctx.union + ctx.inter) \ grad_input = grad_output * 2 * (target * self.union + self.inter) \
/ ctx.union * ctx.union / self.union * self.union
if self.needs_input_grad[1]: if self.needs_input_grad[1]:
grad_target = None grad_target = None
return grad_input, grad_target return grad_input, grad_target
def dice_coeff(input, target): def dice_coeff(input, target):
return DiceCoeff().forward(input, target) """Dice coeff for batches"""
if input.is_cuda:
s = Variable(torch.FloatTensor(1).cuda().zero_())
else:
s = Variable(torch.FloatTensor(1).zero_())
for i, c in enumerate(zip(input, target)):
s = s + DiceCoeff().forward(c[0], c[1])
return s / (i+1)
class DiceLoss(_Loss): class DiceLoss(_Loss):
def forward(self, input, target): def forward(self, input, target):

105
train.py
View file

@ -0,0 +1,105 @@
import torch
from load import *
from data_vis import *
from utils import split_train_val, batch
from myloss import DiceLoss
from unet_model import UNet
from torch.autograd import Variable
from torch import optim
from optparse import OptionParser
def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
cp=True, gpu=False):
dir_img = 'data/train/'
dir_mask = 'data/train_masks/'
dir_checkpoint = 'checkpoints/'
# get ids
ids = get_ids(dir_img)
ids = split_ids(ids)
iddataset = split_train_val(ids, val_percent)
print('''
Starting training:
Epochs: {}
Batch size: {}
Learning rate: {}
Training size: {}
Validation size: {}
Checkpoints: {}
CUDA: {}
'''.format(epochs, batch_size, lr, len(iddataset['train']),
len(iddataset['val']), str(cp), str(gpu)))
N_train = len(iddataset['train'])
train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask)
val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask)
optimizer = optim.Adam(net.parameters(), lr=lr)
criterion = DiceLoss()
for epoch in range(epochs):
print('Starting epoch {}/{}.'.format(epoch+1, epochs))
epoch_loss = 0
for i, b in enumerate(batch(train, batch_size)):
X = np.array([i[0] for i in b])
y = np.array([i[1] for i in b])
X = torch.FloatTensor(X)
y = torch.ByteTensor(y)
if gpu:
X = Variable(X).cuda()
y = Variable(y).cuda()
else:
X = Variable(X)
y = Variable(y)
optimizer.zero_grad()
y_pred = net(X)
loss = criterion(y_pred, y.float())
epoch_loss += loss.data[0]
print('{0:.4f} --- loss: {1:.6f}'.format(i*batch_size/N_train,
loss.data[0]))
loss.backward()
optimizer.step()
print('Epoch finished ! Loss: {}'.format(epoch_loss/i))
if cp:
torch.save(net.state_dict(),
dir_checkpoint + 'CP{}.pth'.format(epoch+1))
print('Checkpoint {} saved !'.format(epoch+1))
parser = OptionParser()
parser.add_option("-e", "--epochs", dest="epochs", default=5, type="int",
help="number of epochs")
parser.add_option("-b", "--batch-size", dest="batchsize", default=10,
type="int", help="batch size")
parser.add_option("-l", "--learning-rate", dest="lr", default=0.1,
type="int", help="learning rate")
parser.add_option("-g", "--gpu", action="store_true", dest="gpu",
default=False, help="use cuda")
parser.add_option("-n", "--ngpu", action="store_false", dest="gpu",
default=False, help="use cuda")
(options, args) = parser.parse_args()
net = UNet(3, 1)
if options.gpu:
net.cuda()
train_net(net, options.epochs, options.batchsize, options.lr, gpu=options.gpu)

View file

@ -4,6 +4,7 @@ import torch.nn.functional as F
from unet_parts import * from unet_parts import *
class UNet(nn.Module): class UNet(nn.Module):
def __init__(self, n_channels, n_classes): def __init__(self, n_channels, n_classes):
super(UNet, self).__init__() super(UNet, self).__init__()

View file

@ -4,6 +4,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
class double_conv(nn.Module): class double_conv(nn.Module):
def __init__(self, in_ch, out_ch): def __init__(self, in_ch, out_ch):
super(double_conv, self).__init__() super(double_conv, self).__init__()
@ -13,10 +14,12 @@ class double_conv(nn.Module):
nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.ReLU() nn.ReLU()
) )
def forward(self, x): def forward(self, x):
x = self.conv(x) x = self.conv(x)
return x return x
class inconv(nn.Module): class inconv(nn.Module):
def __init__(self, in_ch, out_ch): def __init__(self, in_ch, out_ch):
super(inconv, self).__init__() super(inconv, self).__init__()
@ -26,6 +29,7 @@ class inconv(nn.Module):
x = self.conv(x) x = self.conv(x)
return x return x
class down(nn.Module): class down(nn.Module):
def __init__(self, in_ch, out_ch): def __init__(self, in_ch, out_ch):
super(down, self).__init__() super(down, self).__init__()
@ -38,15 +42,15 @@ class down(nn.Module):
x = self.mpconv(x) x = self.mpconv(x)
return x return x
class up(nn.Module): class up(nn.Module):
def __init__(self, in_ch, out_ch): def __init__(self, in_ch, out_ch):
super(up, self).__init__() super(up, self).__init__()
self.up = nn.UpsamplingBilinear2d(scale_factor=2) self.up = nn.UpsamplingBilinear2d(scale_factor=2)
#self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2) # self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)
self.conv = double_conv(in_ch, out_ch) self.conv = double_conv(in_ch, out_ch)
def forward(self, x1, x2): def forward(self, x1, x2):
x1 = self.up(x1) x1 = self.up(x1)
diffX = x1.size()[2] - x2.size()[2] diffX = x1.size()[2] - x2.size()[2]
diffY = x1.size()[3] - x2.size()[3] diffY = x1.size()[3] - x2.size()[3]
@ -56,6 +60,7 @@ class up(nn.Module):
x = self.conv(x) x = self.conv(x)
return x return x
class outconv(nn.Module): class outconv(nn.Module):
def __init__(self, in_ch, out_ch): def __init__(self, in_ch, out_ch):
super(outconv, self).__init__() super(outconv, self).__init__()

View file

@ -1,26 +1,54 @@
import PIL import PIL
import numpy as np import numpy as np
import random
def get_square(img, pos): def get_square(img, pos):
"""Extract a left or a right square from PILimg""" """Extract a left or a right square from PILimg shape : (H, W, C))"""
"""shape : (H, W, C))"""
img = np.array(img) img = np.array(img)
h = img.shape[0] h = img.shape[0]
w = img.shape[1]
if pos == 0: if pos == 0:
return img[:, :h] return img[:, :h]
else: else:
return img[:, -h:] return img[:, -h:]
def resize_and_crop(pilimg, scale=0.5, final_height=640):
def resize_and_crop(pilimg, scale=0.2, final_height=None):
w = pilimg.size[0] w = pilimg.size[0]
h = pilimg.size[1] h = pilimg.size[1]
newW = int(w * scale) newW = int(w * scale)
newH = int(h * scale) newH = int(h * scale)
if not final_height:
diff = 0
else:
diff = newH - final_height diff = newH - final_height
img = pilimg.resize((newW, newH)) img = pilimg.resize((newW, newH))
img = img.crop((0, diff // 2, newW, newH - diff // 2)) img = img.crop((0, diff // 2, newW, newH - diff // 2))
return img return img
def batch(iterable, batch_size):
"""Yields lists by batch"""
b = []
for i, t in enumerate(iterable):
b.append(t)
if (i+1) % batch_size == 0:
yield b
b = []
if len(b) > 0:
yield b
def split_train_val(dataset, val_percent=0.05):
dataset = list(dataset)
length = len(dataset)
n = int(length * val_percent)
random.shuffle(dataset)
return {'train': dataset[:-n], 'val': dataset[-n:]}
def normalize(x):
return x / 255