From c8c82204bfa636ab629b9d0614a9a4a9d11bde1e Mon Sep 17 00:00:00 2001
From: milesial
Date: Wed, 16 Aug 2017 14:24:29 +0200
Subject: [PATCH] First prototype slow, no gpu, no validation

---
 .gitignore    |   3 ++
 data_vis.py   |  13 +++++++
 main.py       | 105 ++++++++++++++++++++++++++++++++++++++++++++++++++
 myloss.py     |  44 +++++++++++++++++++++
 train.py      |   0
 unet_model.py |  34 ++++++++++++++++
 unet_parts.py |  73 +++++++++++++++++++++++++++++++++++
 utils.py      |  31 +++++++++++++++
 8 files changed, 303 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 data_vis.py
 create mode 100644 main.py
 create mode 100644 myloss.py
 create mode 100644 train.py
 create mode 100644 unet_model.py
 create mode 100644 unet_parts.py
 create mode 100644 utils.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a3f457d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+*.pyc
+data/
+__pycache__/
diff --git a/data_vis.py b/data_vis.py
new file mode 100644
index 0000000..714acc5
--- /dev/null
+++ b/data_vis.py
@@ -0,0 +1,13 @@
+import matplotlib.pyplot as plt
+
+def plot_img_mask(img, mask):
+    """Display `img` and `mask` side by side with matplotlib."""
+    fig = plt.figure()
+
+    ax1 = fig.add_subplot(1, 3, 1)
+    ax1.imshow(img)
+
+    ax2 = fig.add_subplot(1, 3, 2)
+    ax2.imshow(mask)
+
+    plt.show()
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..9df9eef
--- /dev/null
+++ b/main.py
@@ -0,0 +1,105 @@
+#models
+from unet_model import UNet
+from myloss import *
+import torch
+from torch.autograd import Variable
+from torch import optim
+
+#data manipulation
+import numpy as np
+import pandas as pd
+import cv2
+# `import PIL` alone does not load the Image submodule used below.
+import PIL.Image
+
+#load files
+import os
+
+#data vis
+from data_vis import plot_img_mask
+from utils import *
+import matplotlib.pyplot as plt
+
+
+dir = 'data'
+ids = []
+
+# Every training image yields two samples: its left (0) and right (1) square.
+for f in os.listdir(dir + '/train'):
+    id = f[:-4]
+    ids.append([id, 0])
+    ids.append([id, 1])
+
+np.random.shuffle(ids)
+#%%
+
+
+net = UNet(3, 1)
+
+optimizer = optim.Adam(net.parameters(), lr=0.001)
+criterion = DiceLoss()
+
+dataset = []
+epochs = 5
+for epoch in range(epochs):
+    print('epoch {}/{}...'.format(epoch+1, epochs))
+    l = 0
+
+    for i, c in enumerate(ids):
+        id = c[0]
+        pos = c[1]
+        im = PIL.Image.open(dir + '/train/' + id + '.jpg')
+        im = resize_and_crop(im)
+
+        ma = PIL.Image.open(dir + '/train_masks/' + id + '_mask.gif')
+        ma = resize_and_crop(ma)
+
+        left, right = split_into_squares(np.array(im))
+        left_m, right_m = split_into_squares(np.array(ma))
+
+        if pos == 0:
+            X = left
+            y = left_m
+        else:
+            X = right
+            y = right_m
+
+
+        # HWC -> CHW, divide by 255 and add a leading batch dimension.
+        X = np.transpose(X, axes=[2, 0, 1])
+        X = torch.FloatTensor(X / 255).unsqueeze(0)
+        y = Variable(torch.ByteTensor(y))
+
+        X = Variable(X, requires_grad=False)
+
+        optimizer.zero_grad()
+
+        y_pred = net(X).squeeze(1)
+
+
+        loss = criterion(y_pred, y.unsqueeze(0).float())
+
+        l += loss.data[0]
+        loss.backward()
+        optimizer.step()
+
+        # Progress through the current epoch, in percent.
+        print('{0:.4f}%.'.format(i/len(ids)*100))
+
+    print('Loss : {}'.format(l))
+
+
+#%%
+
+
+
+
+#net = UNet(3, 2)
+
+#x = Variable(torch.FloatTensor(np.random.randn(1, 3, 640, 640)))
+
+#y = net(x)
+
+
+#plt.imshow(y[0])
+#plt.show()
diff --git a/myloss.py b/myloss.py
new file mode 100644
index 0000000..e65f0f1
--- /dev/null
+++ b/myloss.py
@@ -0,0 +1,44 @@
+import torch
+from torch.nn.modules.loss import _Loss
+from torch.autograd import Function
+import torch.nn.functional as F
+
+class DiceCoeff(Function):
+    """Dice coefficient: 2 * dot(input, target) / (sum(input) + sum(target))."""
+
+    def forward(ctx, input, target):
+        ctx.save_for_backward(input, target)
+        # 0.0001 keeps the ratio defined when both inputs sum to zero.
+        ctx.inter = torch.dot(input, target) + 0.0001
+        ctx.union = torch.sum(input) + torch.sum(target) + 0.0001
+
+        t = 2*ctx.inter.float()/ctx.union.float()
+        return t
+
+    # This function has only a single output, so it gets only one gradient
+    def backward(ctx, grad_output):
+
+        input, target = ctx.saved_variables
+        grad_input = grad_target = None
+
+        # Quotient rule on 2*inter/union: d/d(input) is
+        # 2 * (target * union - inter) / union**2.
+        if ctx.needs_input_grad[0]:
+            grad_input = grad_output * 2 * (target * ctx.union - ctx.inter) \
+                         / (ctx.union * ctx.union)
+        if ctx.needs_input_grad[1]:
+            grad_target = None
+
+        return grad_input, grad_target
+
+def dice_coeff(input, target):
+    """Return the Dice coefficient of `input` and `target`."""
+    # NOTE(review): forward() is called directly instead of going through
+    # autograd's Function machinery, so backward() above is presumably
+    # never invoked -- confirm before relying on the custom gradient.
+    return DiceCoeff().forward(input, target)
+
+class DiceLoss(_Loss):
+    """Dice loss: 1 - dice_coeff(sigmoid(input), target)."""
+    def forward(self, input, target):
+        return 1 - dice_coeff(F.sigmoid(input), target)
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..e69de29
diff --git a/unet_model.py b/unet_model.py
new file mode 100644
index 0000000..2ed6b73
--- /dev/null
+++ b/unet_model.py
@@ -0,0 +1,34 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from unet_parts import *
+
+class UNet(nn.Module):
+    """U-Net: four down blocks, four skip-connected up blocks, 1x1 output."""
+    def __init__(self, n_channels, n_classes):
+        super(UNet, self).__init__()
+        self.inc = inconv(n_channels, 64)
+        self.down1 = down(64, 128)
+        self.down2 = down(128, 256)
+        self.down3 = down(256, 512)
+        self.down4 = down(512, 1024)
+        self.up1 = up(1024, 512)
+        self.up2 = up(512, 256)
+        self.up3 = up(256, 128)
+        self.up4 = up(128, 64)
+        self.outc = outconv(64, n_classes)
+
+    def forward(self, x):
+        x1 = self.inc(x)
+        x2 = self.down1(x1)
+        x3 = self.down2(x2)
+        x4 = self.down3(x3)
+        x5 = self.down4(x4)
+        # Each up block also receives the same-depth encoder output (skip).
+        x = self.up1(x5, x4)
+        x = self.up2(x, x3)
+        x = self.up3(x, x2)
+        x = self.up4(x, x1)
+        x = self.outc(x)
+        return x
diff --git a/unet_parts.py b/unet_parts.py
new file mode 100644
index 0000000..9416dbb
--- /dev/null
+++ b/unet_parts.py
@@ -0,0 +1,73 @@
+# sub-parts of the U-Net model
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class double_conv(nn.Module):
+    """Two 3x3 convolutions (padding 1), each followed by a ReLU."""
+    def __init__(self, in_ch, out_ch):
+        super(double_conv, self).__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_ch, out_ch, 3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(out_ch, out_ch, 3, padding=1),
+            nn.ReLU()
+        )
+    def forward(self, x):
+        x = self.conv(x)
+        return x
+
+class inconv(nn.Module):
+    """Input block: a single double_conv."""
+    def __init__(self, in_ch, out_ch):
+        super(inconv, self).__init__()
+        self.conv = double_conv(in_ch, out_ch)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
+
+class down(nn.Module):
+    """Downscaling block: 2x2 max-pool followed by a double_conv."""
+    def __init__(self, in_ch, out_ch):
+        super(down, self).__init__()
+        self.mpconv = nn.Sequential(
+            nn.MaxPool2d(2),
+            double_conv(in_ch, out_ch)
+        )
+
+    def forward(self, x):
+        x = self.mpconv(x)
+        return x
+
+class up(nn.Module):
+    """Upscaling block: transposed conv, concat with the skip, double_conv."""
+    def __init__(self, in_ch, out_ch):
+        super(up, self).__init__()
+        self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)
+        self.conv = double_conv(in_ch, out_ch)
+
+    def forward(self, x1, x2):
+
+        x1 = self.up(x1)
+        # Pad the skip tensor x2 so its H and W match the upsampled x1.
+        diffX = x1.size()[2] - x2.size()[2]  # height difference (dim 2)
+        diffY = x1.size()[3] - x2.size()[3]  # width difference (dim 3)
+        # F.pad takes the LAST dimension first: (left, right, top, bottom).
+        # `d - d // 2` keeps the total equal to `d` even when `d` is odd.
+        x2 = F.pad(x2, (diffY // 2, diffY - diffY // 2,
+                        diffX // 2, diffX - diffX // 2))
+        x = torch.cat([x2, x1], dim=1)
+        x = self.conv(x)
+        return x
+
+class outconv(nn.Module):
+    """Output block: a 1x1 convolution to `out_ch` channels."""
+    def __init__(self, in_ch, out_ch):
+        super(outconv, self).__init__()
+        self.conv = nn.Conv2d(in_ch, out_ch, 1)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..2664254
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,31 @@
+import PIL
+
+def split_into_squares(img):
+    """Extract a left and a right square from ndarray
+
+    shape : (H, W, C)
+    """
+    h = img.shape[0]
+
+
+    left = img[:, :h]
+    right = img[:, -h:]
+
+    return left, right
+
+def resize_and_crop(pilimg, scale=0.5, final_height=640):
+    """Scale `pilimg` by `scale`, then crop it vertically to `final_height`.
+
+    The crop box starts `diff // 2` rows from the top and is exactly
+    `final_height` rows tall, even when the height difference is odd.
+    """
+    w = pilimg.size[0]
+    h = pilimg.size[1]
+    newW = int(w * scale)
+    newH = int(h * scale)
+    diff = newH - final_height
+
+    img = pilimg.resize((newW, newH))
+    top = diff // 2
+    img = img.crop((0, top, newW, top + final_height))
+    return img