Created a basic training loop and made small changes to the loss and utils
This commit is contained in:
parent
8332f891c3
commit
4063565295
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -1,4 +1,6 @@
|
|||
*.pyc
|
||||
data/
|
||||
__pycache__/
|
||||
checkpoints/
|
||||
*.pth
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def plot_img_mask(img, mask):
|
||||
fig = plt.figure()
|
||||
|
||||
|
|
31
load.py
31
load.py
|
@ -1,47 +1,42 @@
|
|||
|
||||
#
|
||||
# load.py : utils on generators / lists of ids to transform from strings to
|
||||
# cropped images and masks
|
||||
|
||||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image
|
||||
from functools import partial
|
||||
from utils import resize_and_crop, get_square
|
||||
from utils import resize_and_crop, get_square, normalize
|
||||
|
||||
|
||||
def get_ids(dir):
    """Return a lazy stream of the ids found in the directory.

    An id is an entry name with its last four characters removed
    (e.g. the ``.jpg`` extension). The directory is listed as soon as the
    function is called; the ids themselves are produced on demand.
    """
    entries = os.listdir(dir)
    return (name[:-4] for name in entries)
|
||||
|
||||
|
||||
def split_ids(ids, n=2):
    """Split each id in n, creating n tuples (id, k) for each id.

    Args:
        ids: iterable of ids. It may be a one-shot iterator (such as a
            generator); it is consumed once, when the function is called.
        n: number of tuples to create per id.

    Returns:
        A generator yielding (id, k) for every k in range(n), with all the
        ids for k == 0 first, then all the ids for k == 1, and so on.
    """
    # The ids are walked once per k, so a generator would be exhausted
    # after the first pass and only the k == 0 tuples would come out.
    pool = list(ids)
    return ((item, k) for k in range(n) for item in pool)
|
||||
|
||||
def shuffle_ids(ids):
    """Return a new list holding the ids in random order.

    The input iterable is copied first, so it is never modified.
    """
    shuffled = [item for item in ids]
    random.shuffle(shuffled)
    return shuffled
|
||||
|
||||
def to_cropped_imgs(ids, dir, suffix):
    """From (id, pos) tuples, lazily yield the matching cropped image.

    For each tuple, the file ``dir + id + suffix`` is opened, passed
    through ``resize_and_crop``, and the square selected by ``pos`` is
    extracted with ``get_square``.
    """
    for name, pos in ids:
        path = dir + name + suffix
        resized = resize_and_crop(Image.open(path))
        yield get_square(resized, pos)
|
||||
|
||||
|
||||
|
||||
def get_imgs_and_masks(ids, dir_img, dir_mask):
    """Return all the couples (img, mask).

    Args:
        ids: iterable of (id, pos) tuples; it may be a one-shot iterator.
        dir_img: directory prefix of the ``.jpg`` images.
        dir_mask: directory prefix of the ``_mask.gif`` masks.

    Returns:
        A lazy iterator of (img, mask) pairs, where img has been transposed
        from HWC to CHW and passed through ``normalize``.
    """
    # The ids feed two independent generators that zip() advances in turn.
    # With a one-shot iterator each generator would steal ids from the
    # other and images would be paired with the wrong masks.
    ids = list(ids)

    imgs = to_cropped_imgs(ids, dir_img, '.jpg')

    # need to transform from HWC to CHW
    imgs_switched = map(partial(np.transpose, axes=[2, 0, 1]), imgs)
    imgs_normalized = map(normalize, imgs_switched)

    masks = to_cropped_imgs(ids, dir_mask, '_mask.gif')

    return zip(imgs_normalized, masks)
|
||||
|
|
44
myloss.py
44
myloss.py
|
@ -1,34 +1,52 @@
|
|||
|
||||
#
|
||||
# myloss.py : implementation of the Dice coeff and the associated loss
|
||||
#
|
||||
|
||||
import torch
|
||||
from torch.nn.modules.loss import _Loss
|
||||
from torch.autograd import Function
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch.nn.modules.loss import _Loss
|
||||
from torch.autograd import Function, Variable
|
||||
|
||||
|
||||
class DiceCoeff(Function):
    """Dice coeff for individual examples.

    With I = dot(input, target) + eps and U = sum(input) + sum(target) + eps
    (eps = 0.0001), the forward pass returns 2 * I / U.
    """

    def forward(self, input, target):
        """Compute 2 * I / U and store I and U on the instance for backward."""
        self.save_for_backward(input, target)
        self.inter = torch.dot(input, target) + 0.0001
        self.union = torch.sum(input) + torch.sum(target) + 0.0001

        t = 2 * self.inter.float() / self.union.float()
        return t

    @staticmethod
    def _input_gradient(grad_output, target, inter, union):
        """Gradient of 2 * inter / union with respect to the input.

        d(2 * I / U) / d(input_i) = 2 * (target_i * U - I) / U ** 2
        """
        return grad_output * 2 * (target * union - inter) / (union * union)

    # This function has only a single output, so it gets only one gradient
    def backward(self, grad_output):
        """Return (grad_input, grad_target); grad_target is always None."""
        input, target = self.saved_variables
        grad_input = grad_target = None

        if self.needs_input_grad[0]:
            grad_input = self._input_gradient(grad_output, target,
                                              self.inter, self.union)

        # No gradient is ever propagated to the target.
        return grad_input, grad_target
|
||||
|
||||
|
||||
def dice_coeff(input, target):
    """Dice coeff for batches.

    Iterates over ``input`` and ``target`` in lockstep, computes the Dice
    coefficient of every (example, target) pair with ``DiceCoeff`` and
    returns the mean.

    Raises:
        ValueError: if the batch contains no example.
    """
    # Accumulator lives on the same kind of device as the input.
    if input.is_cuda:
        s = Variable(torch.FloatTensor(1).cuda().zero_())
    else:
        s = Variable(torch.FloatTensor(1).zero_())

    count = 0
    for example, truth in zip(input, target):
        s = s + DiceCoeff().forward(example, truth)
        count += 1

    if count == 0:
        raise ValueError('dice_coeff needs at least one example in the batch')

    return s / count
|
||||
|
||||
|
||||
class DiceLoss(_Loss):
|
||||
def forward(self, input, target):
|
||||
|
|
105
train.py
105
train.py
|
@ -0,0 +1,105 @@
|
|||
import torch
|
||||
|
||||
from load import *
|
||||
from data_vis import *
|
||||
from utils import split_train_val, batch
|
||||
from myloss import DiceLoss
|
||||
from unet_model import UNet
|
||||
from torch.autograd import Variable
|
||||
from torch import optim
|
||||
from optparse import OptionParser
|
||||
|
||||
|
||||
def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
              cp=True, gpu=False):
    """Train ``net`` with Adam and the Dice loss on the images in data/train/.

    Args:
        net: the network to train; called as ``net(X)`` on each batch.
        epochs: number of passes over the training ids.
        batch_size: number of (img, mask) couples per batch.
        lr: learning rate given to the Adam optimizer.
        val_percent: fraction of the ids set aside as validation ids. They
            are only counted in the start-up banner; no validation pass is
            run here.
        cp: when true, save ``net.state_dict()`` to checkpoints/ after
            every epoch.
        gpu: when true, move every batch to CUDA before the forward pass.
    """
    # Local import: this module only gets `os` indirectly (star import).
    import os

    dir_img = 'data/train/'
    dir_mask = 'data/train_masks/'
    dir_checkpoint = 'checkpoints/'

    # get ids
    ids = get_ids(dir_img)
    ids = split_ids(ids)

    iddataset = split_train_val(ids, val_percent)

    print('''
    Starting training:
        Epochs: {}
        Batch size: {}
        Learning rate: {}
        Training size: {}
        Validation size: {}
        Checkpoints: {}
        CUDA: {}
    '''.format(epochs, batch_size, lr, len(iddataset['train']),
               len(iddataset['val']), str(cp), str(gpu)))

    N_train = len(iddataset['train'])

    optimizer = optim.Adam(net.parameters(), lr=lr)
    criterion = DiceLoss()

    # torch.save does not create missing directories.
    if cp and not os.path.isdir(dir_checkpoint):
        os.makedirs(dir_checkpoint)

    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch+1, epochs))

        # get_imgs_and_masks returns a one-shot iterator, so it has to be
        # rebuilt for every epoch; otherwise only the first epoch trains.
        train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask)

        epoch_loss = 0
        n_batches = 0

        for i, b in enumerate(batch(train, batch_size)):
            X = np.array([couple[0] for couple in b])
            y = np.array([couple[1] for couple in b])

            X = torch.FloatTensor(X)
            y = torch.ByteTensor(y)

            if gpu:
                X = Variable(X).cuda()
                y = Variable(y).cuda()
            else:
                X = Variable(X)
                y = Variable(y)

            optimizer.zero_grad()

            y_pred = net(X)

            loss = criterion(y_pred, y.float())
            epoch_loss += loss.data[0]
            n_batches += 1

            print('{0:.4f} --- loss: {1:.6f}'.format(i*batch_size/N_train,
                                                     loss.data[0]))

            loss.backward()
            optimizer.step()

        # Average over the number of batches actually seen (the last index
        # is one too small and is 0 when there is a single batch).
        print('Epoch finished ! Loss: {}'.format(
            epoch_loss / max(n_batches, 1)))

        if cp:
            torch.save(net.state_dict(),
                       dir_checkpoint + 'CP{}.pth'.format(epoch+1))

            print('Checkpoint {} saved !'.format(epoch+1))
|
||||
|
||||
|
||||
# Command-line entry point: parse the options, build the network and train.
parser = OptionParser()
parser.add_option("-e", "--epochs", dest="epochs", default=5, type="int",
                  help="number of epochs")
parser.add_option("-b", "--batch-size", dest="batchsize", default=10,
                  type="int", help="batch size")
# The learning rate is fractional: with type="int", "-l 0.01" is rejected.
parser.add_option("-l", "--learning-rate", dest="lr", default=0.1,
                  type="float", help="learning rate")
parser.add_option("-g", "--gpu", action="store_true", dest="gpu",
                  default=False, help="use cuda")
parser.add_option("-n", "--ngpu", action="store_false", dest="gpu",
                  default=False, help="do not use cuda")

(options, args) = parser.parse_args()

net = UNet(3, 1)
if options.gpu:
    net.cuda()

train_net(net, options.epochs, options.batchsize, options.lr, gpu=options.gpu)
|
|
@ -4,6 +4,7 @@ import torch.nn.functional as F
|
|||
|
||||
from unet_parts import *
|
||||
|
||||
|
||||
class UNet(nn.Module):
|
||||
def __init__(self, n_channels, n_classes):
|
||||
super(UNet, self).__init__()
|
||||
|
|
|
@ -4,6 +4,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class double_conv(nn.Module):
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super(double_conv, self).__init__()
|
||||
|
@ -13,10 +14,12 @@ class double_conv(nn.Module):
|
|||
nn.Conv2d(out_ch, out_ch, 3, padding=1),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class inconv(nn.Module):
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super(inconv, self).__init__()
|
||||
|
@ -26,6 +29,7 @@ class inconv(nn.Module):
|
|||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class down(nn.Module):
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super(down, self).__init__()
|
||||
|
@ -38,6 +42,7 @@ class down(nn.Module):
|
|||
x = self.mpconv(x)
|
||||
return x
|
||||
|
||||
|
||||
class up(nn.Module):
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super(up, self).__init__()
|
||||
|
@ -46,7 +51,6 @@ class up(nn.Module):
|
|||
self.conv = double_conv(in_ch, out_ch)
|
||||
|
||||
def forward(self, x1, x2):
|
||||
|
||||
x1 = self.up(x1)
|
||||
diffX = x1.size()[2] - x2.size()[2]
|
||||
diffY = x1.size()[3] - x2.size()[3]
|
||||
|
@ -56,6 +60,7 @@ class up(nn.Module):
|
|||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class outconv(nn.Module):
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super(outconv, self).__init__()
|
||||
|
|
40
utils.py
40
utils.py
|
@ -1,26 +1,54 @@
|
|||
import PIL
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
|
||||
def get_square(img, pos):
    """Extract the left or the right square of an image.

    The image is converted to an array whose first axis is the height and
    whose second axis is the width. The square's side is the height:
    ``pos == 0`` keeps the first ``h`` columns, any other value keeps the
    last ``h`` columns.
    """
    pixels = np.array(img)
    side = pixels.shape[0]

    if pos == 0:
        return pixels[:, :side]
    return pixels[:, -side:]
|
||||
|
||||
def resize_and_crop(pilimg, scale=0.2, final_height=None):
    """Resize an image by ``scale`` and optionally crop it vertically.

    Args:
        pilimg: image exposing ``size`` as (width, height), plus
            ``resize`` and ``crop``.
        scale: factor applied to both the width and the height.
        final_height: when given (and non-zero), the resized image is
            cropped around its vertical centre to exactly this many rows;
            otherwise the full resized height is kept.

    Returns:
        The resized and cropped image.
    """
    w = pilimg.size[0]
    h = pilimg.size[1]
    newW = int(w * scale)
    newH = int(h * scale)

    img = pilimg.resize((newW, newH))

    if not final_height:
        return img.crop((0, 0, newW, newH))

    # Derive the bottom edge from the top edge: trimming diff // 2 from both
    # sides leaves one extra row whenever the difference is odd.
    top = (newH - final_height) // 2
    return img.crop((0, top, newW, top + final_height))
|
||||
|
||||
|
||||
def batch(iterable, batch_size):
    """Yield the items of ``iterable`` grouped in lists of ``batch_size``.

    The last list may be shorter when the items do not divide evenly;
    an empty iterable yields nothing.
    """
    chunk = []
    for item in iterable:
        chunk.append(item)
        # The chunk restarts empty after every full batch, so its length
        # modulo batch_size tracks the running item count.
        if len(chunk) % batch_size == 0:
            yield chunk
            chunk = []

    if chunk:
        yield chunk
|
||||
|
||||
|
||||
def split_train_val(dataset, val_percent=0.05):
    """Shuffle the dataset and split it into a train part and a val part.

    Args:
        dataset: iterable of samples; it is copied into a list first.
        val_percent: fraction of the samples placed in the validation part
            (the count is truncated towards zero).

    Returns:
        A dict ``{'train': [...], 'val': [...]}`` whose two lists together
        hold every sample exactly once.
    """
    dataset = list(dataset)
    length = len(dataset)
    # Clamp so an out-of-range fraction cannot produce a wrapped slice.
    n = min(max(int(length * val_percent), 0), length)
    random.shuffle(dataset)

    # Split on an explicit index: with n == 0, dataset[:-n] would be empty
    # and dataset[-n:] would be the whole dataset.
    cut = length - n
    return {'train': dataset[:cut], 'val': dataset[cut:]}
|
||||
|
||||
|
||||
def normalize(x):
    """Divide ``x`` by 255.

    NOTE(review): presumably used to map 8-bit pixel values into [0, 1];
    confirm against the callers.
    """
    max_value = 255
    return x / max_value
|
||||
|
|
Loading…
Reference in a new issue