mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
Created a basic train loop + changed a bit loss and utils
This commit is contained in:
parent
8332f891c3
commit
4063565295
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -1,4 +1,6 @@
|
||||||
*.pyc
|
*.pyc
|
||||||
data/
|
data/
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
checkpoints/
|
||||||
*.pth
|
*.pth
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
def plot_img_mask(img, mask):
|
def plot_img_mask(img, mask):
|
||||||
fig = plt.figure()
|
fig = plt.figure()
|
||||||
|
|
||||||
|
|
31
load.py
31
load.py
|
@ -1,47 +1,42 @@
|
||||||
|
|
||||||
|
#
|
||||||
|
# load.py : utils on generators / lists of ids to transform from strings to
|
||||||
|
# cropped images and masks
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from utils import resize_and_crop, get_square
|
from utils import resize_and_crop, get_square, normalize
|
||||||
|
|
||||||
|
|
||||||
def get_ids(dir):
|
def get_ids(dir):
|
||||||
"""Returns a list of the ids in the directory"""
|
"""Returns a list of the ids in the directory"""
|
||||||
return (f[:-4] for f in os.listdir(dir))
|
return (f[:-4] for f in os.listdir(dir))
|
||||||
|
|
||||||
|
|
||||||
def split_ids(ids, n=2):
|
def split_ids(ids, n=2):
|
||||||
"""Split each id in n, creating n tuples (id, k) for each id"""
|
"""Split each id in n, creating n tuples (id, k) for each id"""
|
||||||
return ((id, i) for i in range(n) for id in ids)
|
return ((id, i) for i in range(n) for id in ids)
|
||||||
|
|
||||||
def shuffle_ids(ids):
|
|
||||||
"""Returns a shuffle list od the ids"""
|
|
||||||
lst = list(ids)
|
|
||||||
random.shuffle(lst)
|
|
||||||
return lst
|
|
||||||
|
|
||||||
def to_cropped_imgs(ids, dir, suffix):
|
def to_cropped_imgs(ids, dir, suffix):
|
||||||
"""From a list of tuples, returns the correct cropped img (left or right)"""
|
"""From a list of tuples, returns the correct cropped img"""
|
||||||
for id, pos in ids:
|
for id, pos in ids:
|
||||||
im = resize_and_crop(Image.open(dir + id + suffix))
|
im = resize_and_crop(Image.open(dir + id + suffix))
|
||||||
yield get_square(im, pos)
|
yield get_square(im, pos)
|
||||||
|
|
||||||
|
|
||||||
|
def get_imgs_and_masks(ids, dir_img, dir_mask):
|
||||||
def get_imgs_and_masks():
|
"""Return all the couples (img, mask)"""
|
||||||
"""From the list of ids, return the couples (img, mask)"""
|
|
||||||
dir_img = 'data/train/'
|
|
||||||
dir_mask = 'data/train_masks/'
|
|
||||||
|
|
||||||
ids = get_ids(dir_img)
|
|
||||||
ids = split_ids(ids)
|
|
||||||
ids = shuffle_ids(ids)
|
|
||||||
|
|
||||||
imgs = to_cropped_imgs(ids, dir_img, '.jpg')
|
imgs = to_cropped_imgs(ids, dir_img, '.jpg')
|
||||||
|
|
||||||
# need to transform from HWC to CHW
|
# need to transform from HWC to CHW
|
||||||
imgs_switched = map(partial(np.transpose, axes=[2, 0, 1]), imgs)
|
imgs_switched = map(partial(np.transpose, axes=[2, 0, 1]), imgs)
|
||||||
|
imgs_normalized = map(normalize, imgs_switched)
|
||||||
|
|
||||||
masks = to_cropped_imgs(ids, dir_mask, '_mask.gif')
|
masks = to_cropped_imgs(ids, dir_mask, '_mask.gif')
|
||||||
|
|
||||||
return zip(imgs_switched, masks)
|
return zip(imgs_normalized, masks)
|
||||||
|
|
44
myloss.py
44
myloss.py
|
@ -1,34 +1,52 @@
|
||||||
|
|
||||||
|
#
|
||||||
|
# myloss.py : implementation of the Dice coeff and the associated loss
|
||||||
|
#
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.modules.loss import _Loss
|
|
||||||
from torch.autograd import Function
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from torch.nn.modules.loss import _Loss
|
||||||
|
from torch.autograd import Function, Variable
|
||||||
|
|
||||||
|
|
||||||
class DiceCoeff(Function):
|
class DiceCoeff(Function):
|
||||||
|
"""Dice coeff for individual examples"""
|
||||||
|
def forward(self, input, target):
|
||||||
|
self.save_for_backward(input, target)
|
||||||
|
self.inter = torch.dot(input, target) + 0.0001
|
||||||
|
self.union = torch.sum(input) + torch.sum(target) + 0.0001
|
||||||
|
|
||||||
def forward(ctx, input, target):
|
t = 2*self.inter.float()/self.union.float()
|
||||||
ctx.save_for_backward(input, target)
|
|
||||||
ctx.inter = torch.dot(input, target) + 0.0001
|
|
||||||
ctx.union = torch.sum(input) + torch.sum(target) + 0.0001
|
|
||||||
|
|
||||||
t = 2*ctx.inter.float()/ctx.union.float()
|
|
||||||
return t
|
return t
|
||||||
|
|
||||||
# This function has only a single output, so it gets only one gradient
|
# This function has only a single output, so it gets only one gradient
|
||||||
def backward(ctx, grad_output):
|
def backward(self, grad_output):
|
||||||
|
|
||||||
input, target = ctx.saved_variables
|
input, target = self.saved_variables
|
||||||
grad_input = grad_target = None
|
grad_input = grad_target = None
|
||||||
|
|
||||||
if self.needs_input_grad[0]:
|
if self.needs_input_grad[0]:
|
||||||
grad_input = grad_output * 2 * (target * ctx.union + ctx.inter) \
|
grad_input = grad_output * 2 * (target * self.union + self.inter) \
|
||||||
/ ctx.union * ctx.union
|
/ self.union * self.union
|
||||||
if self.needs_input_grad[1]:
|
if self.needs_input_grad[1]:
|
||||||
grad_target = None
|
grad_target = None
|
||||||
|
|
||||||
return grad_input, grad_target
|
return grad_input, grad_target
|
||||||
|
|
||||||
|
|
||||||
def dice_coeff(input, target):
|
def dice_coeff(input, target):
|
||||||
return DiceCoeff().forward(input, target)
|
"""Dice coeff for batches"""
|
||||||
|
if input.is_cuda:
|
||||||
|
s = Variable(torch.FloatTensor(1).cuda().zero_())
|
||||||
|
else:
|
||||||
|
s = Variable(torch.FloatTensor(1).zero_())
|
||||||
|
|
||||||
|
for i, c in enumerate(zip(input, target)):
|
||||||
|
s = s + DiceCoeff().forward(c[0], c[1])
|
||||||
|
|
||||||
|
return s / (i+1)
|
||||||
|
|
||||||
|
|
||||||
class DiceLoss(_Loss):
|
class DiceLoss(_Loss):
|
||||||
def forward(self, input, target):
|
def forward(self, input, target):
|
||||||
|
|
105
train.py
105
train.py
|
@ -0,0 +1,105 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from load import *
|
||||||
|
from data_vis import *
|
||||||
|
from utils import split_train_val, batch
|
||||||
|
from myloss import DiceLoss
|
||||||
|
from unet_model import UNet
|
||||||
|
from torch.autograd import Variable
|
||||||
|
from torch import optim
|
||||||
|
from optparse import OptionParser
|
||||||
|
|
||||||
|
|
||||||
|
def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
|
||||||
|
cp=True, gpu=False):
|
||||||
|
dir_img = 'data/train/'
|
||||||
|
dir_mask = 'data/train_masks/'
|
||||||
|
dir_checkpoint = 'checkpoints/'
|
||||||
|
|
||||||
|
# get ids
|
||||||
|
ids = get_ids(dir_img)
|
||||||
|
ids = split_ids(ids)
|
||||||
|
|
||||||
|
iddataset = split_train_val(ids, val_percent)
|
||||||
|
|
||||||
|
print('''
|
||||||
|
Starting training:
|
||||||
|
Epochs: {}
|
||||||
|
Batch size: {}
|
||||||
|
Learning rate: {}
|
||||||
|
Training size: {}
|
||||||
|
Validation size: {}
|
||||||
|
Checkpoints: {}
|
||||||
|
CUDA: {}
|
||||||
|
'''.format(epochs, batch_size, lr, len(iddataset['train']),
|
||||||
|
len(iddataset['val']), str(cp), str(gpu)))
|
||||||
|
|
||||||
|
N_train = len(iddataset['train'])
|
||||||
|
|
||||||
|
train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask)
|
||||||
|
val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask)
|
||||||
|
|
||||||
|
optimizer = optim.Adam(net.parameters(), lr=lr)
|
||||||
|
criterion = DiceLoss()
|
||||||
|
|
||||||
|
for epoch in range(epochs):
|
||||||
|
print('Starting epoch {}/{}.'.format(epoch+1, epochs))
|
||||||
|
|
||||||
|
epoch_loss = 0
|
||||||
|
|
||||||
|
for i, b in enumerate(batch(train, batch_size)):
|
||||||
|
X = np.array([i[0] for i in b])
|
||||||
|
y = np.array([i[1] for i in b])
|
||||||
|
|
||||||
|
X = torch.FloatTensor(X)
|
||||||
|
y = torch.ByteTensor(y)
|
||||||
|
|
||||||
|
if gpu:
|
||||||
|
X = Variable(X).cuda()
|
||||||
|
y = Variable(y).cuda()
|
||||||
|
else:
|
||||||
|
X = Variable(X)
|
||||||
|
y = Variable(y)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
y_pred = net(X)
|
||||||
|
|
||||||
|
loss = criterion(y_pred, y.float())
|
||||||
|
epoch_loss += loss.data[0]
|
||||||
|
|
||||||
|
print('{0:.4f} --- loss: {1:.6f}'.format(i*batch_size/N_train,
|
||||||
|
loss.data[0]))
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
print('Epoch finished ! Loss: {}'.format(epoch_loss/i))
|
||||||
|
|
||||||
|
if cp:
|
||||||
|
torch.save(net.state_dict(),
|
||||||
|
dir_checkpoint + 'CP{}.pth'.format(epoch+1))
|
||||||
|
|
||||||
|
print('Checkpoint {} saved !'.format(epoch+1))
|
||||||
|
|
||||||
|
|
||||||
|
parser = OptionParser()
|
||||||
|
parser.add_option("-e", "--epochs", dest="epochs", default=5, type="int",
|
||||||
|
help="number of epochs")
|
||||||
|
parser.add_option("-b", "--batch-size", dest="batchsize", default=10,
|
||||||
|
type="int", help="batch size")
|
||||||
|
parser.add_option("-l", "--learning-rate", dest="lr", default=0.1,
|
||||||
|
type="int", help="learning rate")
|
||||||
|
parser.add_option("-g", "--gpu", action="store_true", dest="gpu",
|
||||||
|
default=False, help="use cuda")
|
||||||
|
parser.add_option("-n", "--ngpu", action="store_false", dest="gpu",
|
||||||
|
default=False, help="use cuda")
|
||||||
|
|
||||||
|
|
||||||
|
(options, args) = parser.parse_args()
|
||||||
|
|
||||||
|
net = UNet(3, 1)
|
||||||
|
if options.gpu:
|
||||||
|
net.cuda()
|
||||||
|
|
||||||
|
train_net(net, options.epochs, options.batchsize, options.lr, gpu=options.gpu)
|
|
@ -4,6 +4,7 @@ import torch.nn.functional as F
|
||||||
|
|
||||||
from unet_parts import *
|
from unet_parts import *
|
||||||
|
|
||||||
|
|
||||||
class UNet(nn.Module):
|
class UNet(nn.Module):
|
||||||
def __init__(self, n_channels, n_classes):
|
def __init__(self, n_channels, n_classes):
|
||||||
super(UNet, self).__init__()
|
super(UNet, self).__init__()
|
||||||
|
|
|
@ -4,6 +4,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
class double_conv(nn.Module):
|
class double_conv(nn.Module):
|
||||||
def __init__(self, in_ch, out_ch):
|
def __init__(self, in_ch, out_ch):
|
||||||
super(double_conv, self).__init__()
|
super(double_conv, self).__init__()
|
||||||
|
@ -13,10 +14,12 @@ class double_conv(nn.Module):
|
||||||
nn.Conv2d(out_ch, out_ch, 3, padding=1),
|
nn.Conv2d(out_ch, out_ch, 3, padding=1),
|
||||||
nn.ReLU()
|
nn.ReLU()
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class inconv(nn.Module):
|
class inconv(nn.Module):
|
||||||
def __init__(self, in_ch, out_ch):
|
def __init__(self, in_ch, out_ch):
|
||||||
super(inconv, self).__init__()
|
super(inconv, self).__init__()
|
||||||
|
@ -26,6 +29,7 @@ class inconv(nn.Module):
|
||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class down(nn.Module):
|
class down(nn.Module):
|
||||||
def __init__(self, in_ch, out_ch):
|
def __init__(self, in_ch, out_ch):
|
||||||
super(down, self).__init__()
|
super(down, self).__init__()
|
||||||
|
@ -38,15 +42,15 @@ class down(nn.Module):
|
||||||
x = self.mpconv(x)
|
x = self.mpconv(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class up(nn.Module):
|
class up(nn.Module):
|
||||||
def __init__(self, in_ch, out_ch):
|
def __init__(self, in_ch, out_ch):
|
||||||
super(up, self).__init__()
|
super(up, self).__init__()
|
||||||
self.up = nn.UpsamplingBilinear2d(scale_factor=2)
|
self.up = nn.UpsamplingBilinear2d(scale_factor=2)
|
||||||
#self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)
|
# self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)
|
||||||
self.conv = double_conv(in_ch, out_ch)
|
self.conv = double_conv(in_ch, out_ch)
|
||||||
|
|
||||||
def forward(self, x1, x2):
|
def forward(self, x1, x2):
|
||||||
|
|
||||||
x1 = self.up(x1)
|
x1 = self.up(x1)
|
||||||
diffX = x1.size()[2] - x2.size()[2]
|
diffX = x1.size()[2] - x2.size()[2]
|
||||||
diffY = x1.size()[3] - x2.size()[3]
|
diffY = x1.size()[3] - x2.size()[3]
|
||||||
|
@ -56,6 +60,7 @@ class up(nn.Module):
|
||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class outconv(nn.Module):
|
class outconv(nn.Module):
|
||||||
def __init__(self, in_ch, out_ch):
|
def __init__(self, in_ch, out_ch):
|
||||||
super(outconv, self).__init__()
|
super(outconv, self).__init__()
|
||||||
|
|
42
utils.py
42
utils.py
|
@ -1,26 +1,54 @@
|
||||||
import PIL
|
import PIL
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
def get_square(img, pos):
|
def get_square(img, pos):
|
||||||
"""Extract a left or a right square from PILimg"""
|
"""Extract a left or a right square from PILimg shape : (H, W, C))"""
|
||||||
"""shape : (H, W, C))"""
|
|
||||||
img = np.array(img)
|
img = np.array(img)
|
||||||
|
|
||||||
h = img.shape[0]
|
h = img.shape[0]
|
||||||
w = img.shape[1]
|
|
||||||
|
|
||||||
if pos == 0:
|
if pos == 0:
|
||||||
return img[:, :h]
|
return img[:, :h]
|
||||||
else:
|
else:
|
||||||
return img[:, -h:]
|
return img[:, -h:]
|
||||||
|
|
||||||
def resize_and_crop(pilimg, scale=0.5, final_height=640):
|
|
||||||
|
def resize_and_crop(pilimg, scale=0.2, final_height=None):
|
||||||
w = pilimg.size[0]
|
w = pilimg.size[0]
|
||||||
h = pilimg.size[1]
|
h = pilimg.size[1]
|
||||||
newW = int(w * scale)
|
newW = int(w * scale)
|
||||||
newH = int(h * scale)
|
newH = int(h * scale)
|
||||||
diff = newH - final_height
|
|
||||||
|
if not final_height:
|
||||||
|
diff = 0
|
||||||
|
else:
|
||||||
|
diff = newH - final_height
|
||||||
|
|
||||||
img = pilimg.resize((newW, newH))
|
img = pilimg.resize((newW, newH))
|
||||||
img = img.crop((0, diff // 2, newW, newH - diff // 2))
|
img = img.crop((0, diff // 2, newW, newH - diff // 2))
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def batch(iterable, batch_size):
|
||||||
|
"""Yields lists by batch"""
|
||||||
|
b = []
|
||||||
|
for i, t in enumerate(iterable):
|
||||||
|
b.append(t)
|
||||||
|
if (i+1) % batch_size == 0:
|
||||||
|
yield b
|
||||||
|
b = []
|
||||||
|
|
||||||
|
if len(b) > 0:
|
||||||
|
yield b
|
||||||
|
|
||||||
|
|
||||||
|
def split_train_val(dataset, val_percent=0.05):
|
||||||
|
dataset = list(dataset)
|
||||||
|
length = len(dataset)
|
||||||
|
n = int(length * val_percent)
|
||||||
|
random.shuffle(dataset)
|
||||||
|
return {'train': dataset[:-n], 'val': dataset[-n:]}
|
||||||
|
|
||||||
|
|
||||||
|
def normalize(x):
|
||||||
|
return x / 255
|
||||||
|
|
Loading…
Reference in a new issue