Created a basic train loop + changed a bit loss and utils

This commit is contained in:
milesial 2017-08-17 21:16:19 +02:00
parent 8332f891c3
commit 4063565295
8 changed files with 195 additions and 40 deletions

2
.gitignore vendored
View file

@ -1,4 +1,6 @@
*.pyc
data/
__pycache__/
checkpoints/
*.pth

View file

@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
def plot_img_mask(img, mask):
fig = plt.figure()

31
load.py
View file

@ -1,47 +1,42 @@
#
# load.py : utils on generators / lists of ids to transform from strings to
# cropped images and masks
import os
import random
import numpy as np
from PIL import Image
from functools import partial
from utils import resize_and_crop, get_square
from utils import resize_and_crop, get_square, normalize
def get_ids(dir):
"""Returns a list of the ids in the directory"""
return (f[:-4] for f in os.listdir(dir))
def split_ids(ids, n=2):
"""Split each id in n, creating n tuples (id, k) for each id"""
return ((id, i) for i in range(n) for id in ids)
def shuffle_ids(ids):
"""Returns a shuffle list od the ids"""
lst = list(ids)
random.shuffle(lst)
return lst
def to_cropped_imgs(ids, dir, suffix):
"""From a list of tuples, returns the correct cropped img (left or right)"""
"""From a list of tuples, returns the correct cropped img"""
for id, pos in ids:
im = resize_and_crop(Image.open(dir + id + suffix))
yield get_square(im, pos)
def get_imgs_and_masks():
"""From the list of ids, return the couples (img, mask)"""
dir_img = 'data/train/'
dir_mask = 'data/train_masks/'
ids = get_ids(dir_img)
ids = split_ids(ids)
ids = shuffle_ids(ids)
def get_imgs_and_masks(ids, dir_img, dir_mask):
"""Return all the couples (img, mask)"""
imgs = to_cropped_imgs(ids, dir_img, '.jpg')
# need to transform from HWC to CHW
imgs_switched = map(partial(np.transpose, axes=[2, 0, 1]), imgs)
imgs_normalized = map(normalize, imgs_switched)
masks = to_cropped_imgs(ids, dir_mask, '_mask.gif')
return zip(imgs_switched, masks)
return zip(imgs_normalized, masks)

View file

@ -1,34 +1,52 @@
#
# myloss.py : implementation of the Dice coeff and the associated loss
#
import torch
from torch.nn.modules.loss import _Loss
from torch.autograd import Function
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from torch.autograd import Function, Variable
class DiceCoeff(Function):
"""Dice coeff for individual examples"""
def forward(self, input, target):
self.save_for_backward(input, target)
self.inter = torch.dot(input, target) + 0.0001
self.union = torch.sum(input) + torch.sum(target) + 0.0001
def forward(ctx, input, target):
ctx.save_for_backward(input, target)
ctx.inter = torch.dot(input, target) + 0.0001
ctx.union = torch.sum(input) + torch.sum(target) + 0.0001
t = 2*ctx.inter.float()/ctx.union.float()
t = 2*self.inter.float()/self.union.float()
return t
# This function has only a single output, so it gets only one gradient
def backward(ctx, grad_output):
def backward(self, grad_output):
input, target = ctx.saved_variables
input, target = self.saved_variables
grad_input = grad_target = None
if self.needs_input_grad[0]:
grad_input = grad_output * 2 * (target * ctx.union + ctx.inter) \
/ ctx.union * ctx.union
grad_input = grad_output * 2 * (target * self.union + self.inter) \
/ self.union * self.union
if self.needs_input_grad[1]:
grad_target = None
return grad_input, grad_target
def dice_coeff(input, target):
return DiceCoeff().forward(input, target)
"""Dice coeff for batches"""
if input.is_cuda:
s = Variable(torch.FloatTensor(1).cuda().zero_())
else:
s = Variable(torch.FloatTensor(1).zero_())
for i, c in enumerate(zip(input, target)):
s = s + DiceCoeff().forward(c[0], c[1])
return s / (i+1)
class DiceLoss(_Loss):
def forward(self, input, target):

105
train.py
View file

@ -0,0 +1,105 @@
import torch
from load import *
from data_vis import *
from utils import split_train_val, batch
from myloss import DiceLoss
from unet_model import UNet
from torch.autograd import Variable
from torch import optim
from optparse import OptionParser
def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
cp=True, gpu=False):
dir_img = 'data/train/'
dir_mask = 'data/train_masks/'
dir_checkpoint = 'checkpoints/'
# get ids
ids = get_ids(dir_img)
ids = split_ids(ids)
iddataset = split_train_val(ids, val_percent)
print('''
Starting training:
Epochs: {}
Batch size: {}
Learning rate: {}
Training size: {}
Validation size: {}
Checkpoints: {}
CUDA: {}
'''.format(epochs, batch_size, lr, len(iddataset['train']),
len(iddataset['val']), str(cp), str(gpu)))
N_train = len(iddataset['train'])
train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask)
val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask)
optimizer = optim.Adam(net.parameters(), lr=lr)
criterion = DiceLoss()
for epoch in range(epochs):
print('Starting epoch {}/{}.'.format(epoch+1, epochs))
epoch_loss = 0
for i, b in enumerate(batch(train, batch_size)):
X = np.array([i[0] for i in b])
y = np.array([i[1] for i in b])
X = torch.FloatTensor(X)
y = torch.ByteTensor(y)
if gpu:
X = Variable(X).cuda()
y = Variable(y).cuda()
else:
X = Variable(X)
y = Variable(y)
optimizer.zero_grad()
y_pred = net(X)
loss = criterion(y_pred, y.float())
epoch_loss += loss.data[0]
print('{0:.4f} --- loss: {1:.6f}'.format(i*batch_size/N_train,
loss.data[0]))
loss.backward()
optimizer.step()
print('Epoch finished ! Loss: {}'.format(epoch_loss/i))
if cp:
torch.save(net.state_dict(),
dir_checkpoint + 'CP{}.pth'.format(epoch+1))
print('Checkpoint {} saved !'.format(epoch+1))
parser = OptionParser()
parser.add_option("-e", "--epochs", dest="epochs", default=5, type="int",
help="number of epochs")
parser.add_option("-b", "--batch-size", dest="batchsize", default=10,
type="int", help="batch size")
parser.add_option("-l", "--learning-rate", dest="lr", default=0.1,
type="int", help="learning rate")
parser.add_option("-g", "--gpu", action="store_true", dest="gpu",
default=False, help="use cuda")
parser.add_option("-n", "--ngpu", action="store_false", dest="gpu",
default=False, help="use cuda")
(options, args) = parser.parse_args()
net = UNet(3, 1)
if options.gpu:
net.cuda()
train_net(net, options.epochs, options.batchsize, options.lr, gpu=options.gpu)

View file

@ -4,6 +4,7 @@ import torch.nn.functional as F
from unet_parts import *
class UNet(nn.Module):
def __init__(self, n_channels, n_classes):
super(UNet, self).__init__()

View file

@ -4,6 +4,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
class double_conv(nn.Module):
def __init__(self, in_ch, out_ch):
super(double_conv, self).__init__()
@ -13,10 +14,12 @@ class double_conv(nn.Module):
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.ReLU()
)
def forward(self, x):
x = self.conv(x)
return x
class inconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(inconv, self).__init__()
@ -26,6 +29,7 @@ class inconv(nn.Module):
x = self.conv(x)
return x
class down(nn.Module):
def __init__(self, in_ch, out_ch):
super(down, self).__init__()
@ -38,6 +42,7 @@ class down(nn.Module):
x = self.mpconv(x)
return x
class up(nn.Module):
def __init__(self, in_ch, out_ch):
super(up, self).__init__()
@ -46,7 +51,6 @@ class up(nn.Module):
self.conv = double_conv(in_ch, out_ch)
def forward(self, x1, x2):
x1 = self.up(x1)
diffX = x1.size()[2] - x2.size()[2]
diffY = x1.size()[3] - x2.size()[3]
@ -56,6 +60,7 @@ class up(nn.Module):
x = self.conv(x)
return x
class outconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(outconv, self).__init__()

View file

@ -1,26 +1,54 @@
import PIL
import numpy as np
import random
def get_square(img, pos):
"""Extract a left or a right square from PILimg"""
"""shape : (H, W, C))"""
"""Extract a left or a right square from PILimg shape : (H, W, C))"""
img = np.array(img)
h = img.shape[0]
w = img.shape[1]
if pos == 0:
return img[:, :h]
else:
return img[:, -h:]
def resize_and_crop(pilimg, scale=0.5, final_height=640):
def resize_and_crop(pilimg, scale=0.2, final_height=None):
w = pilimg.size[0]
h = pilimg.size[1]
newW = int(w * scale)
newH = int(h * scale)
if not final_height:
diff = 0
else:
diff = newH - final_height
img = pilimg.resize((newW, newH))
img = img.crop((0, diff // 2, newW, newH - diff // 2))
return img
def batch(iterable, batch_size):
"""Yields lists by batch"""
b = []
for i, t in enumerate(iterable):
b.append(t)
if (i+1) % batch_size == 0:
yield b
b = []
if len(b) > 0:
yield b
def split_train_val(dataset, val_percent=0.05):
dataset = list(dataset)
length = len(dataset)
n = int(length * val_percent)
random.shuffle(dataset)
return {'train': dataset[:-n], 'val': dataset[-n:]}
def normalize(x):
return x / 255