First prototype

slow, no gpu, no validation
This commit is contained in:
milesial 2017-08-16 14:24:29 +02:00
parent 415d600d3a
commit c8c82204bf
8 changed files with 272 additions and 0 deletions

3
.gitignore vendored Normal file
View file

@ -0,0 +1,3 @@
*.pyc
data/
__pycache__/

12
data_vis.py Normal file
View file

@ -0,0 +1,12 @@
import matplotlib.pyplot as plt
def plot_img_mask(img, mask):
fig = plt.figure()
ax1 = fig.add_subplot(1, 3, 1)
ax1.imshow(img)
ax2 = fig.add_subplot(1, 3, 2)
ax2.imshow(mask)
plt.show()

101
main.py Normal file
View file

@ -0,0 +1,101 @@
#models
from unet_model import UNet
from myloss import *
import torch
from torch.autograd import Variable
from torch import optim
#data manipulation
import numpy as np
import pandas as pd
import cv2
import PIL
#load files
import os
#data vis
from data_vis import plot_img_mask
from utils import *
import matplotlib.pyplot as plt
dir = 'data'
ids = []
for f in os.listdir(dir + '/train'):
id = f[:-4]
ids.append([id, 0])
ids.append([id, 1])
np.random.shuffle(ids)
#%%
net = UNet(3, 1)
optimizer = optim.Adam(net.parameters(), lr=0.001)
criterion = DiceLoss()
dataset = []
epochs = 5
for epoch in range(epochs):
print('epoch {}/{}...'.format(epoch+1, epochs))
l = 0
for i, c in enumerate(ids):
id = c[0]
pos = c[1]
im = PIL.Image.open(dir + '/train/' + id + '.jpg')
im = resize_and_crop(im)
ma = PIL.Image.open(dir + '/train_masks/' + id + '_mask.gif')
ma = resize_and_crop(ma)
left, right = split_into_squares(np.array(im))
left_m, right_m = split_into_squares(np.array(ma))
if pos == 0:
X = left
y = left_m
else:
X = right
y = right_m
X = np.transpose(X, axes=[2, 0, 1])
X = torch.FloatTensor(X / 255).unsqueeze(0)
y = Variable(torch.ByteTensor(y))
X = Variable(X, requires_grad=False)
optimizer.zero_grad()
y_pred = net(X).squeeze(1)
loss = criterion(y_pred, y.unsqueeze(0).float())
l += loss.data[0]
loss.backward()
optimizer.step()
print('{0:.4f}%.'.format(i/len(ids)*100, end=''))
print('Loss : {}'.format(l))
#%%
#net = UNet(3, 2)
#x = Variable(torch.FloatTensor(np.random.randn(1, 3, 640, 640)))
#y = net(x)
#plt.imshow(y[0])
#plt.show()

35
myloss.py Normal file
View file

@ -0,0 +1,35 @@
import torch
from torch.nn.modules.loss import _Loss
from torch.autograd import Function
import torch.nn.functional as F
class DiceCoeff(Function):
def forward(ctx, input, target):
ctx.save_for_backward(input, target)
ctx.inter = torch.dot(input, target) + 0.0001
ctx.union = torch.sum(input) + torch.sum(target) + 0.0001
t = 2*ctx.inter.float()/ctx.union.float()
return t
# This function has only a single output, so it gets only one gradient
def backward(ctx, grad_output):
input, target = ctx.saved_variables
grad_input = grad_target = None
if self.needs_input_grad[0]:
grad_input = grad_output * 2 * (target * ctx.union + ctx.inter) \
/ ctx.union * ctx.union
if self.needs_input_grad[1]:
grad_target = None
return grad_input, grad_target
def dice_coeff(input, target):
return DiceCoeff().forward(input, target)
class DiceLoss(_Loss):
def forward(self, input, target):
return 1 - dice_coeff(F.sigmoid(input), target)

0
train.py Normal file
View file

32
unet_model.py Normal file
View file

@ -0,0 +1,32 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from unet_parts import *
class UNet(nn.Module):
def __init__(self, n_channels, n_classes):
super(UNet, self).__init__()
self.inc = inconv(n_channels, 64)
self.down1 = down(64, 128)
self.down2 = down(128, 256)
self.down3 = down(256, 512)
self.down4 = down(512, 1024)
self.up1 = up(1024, 512)
self.up2 = up(512, 256)
self.up3 = up(256, 128)
self.up4 = up(128, 64)
self.outc = outconv(64, n_classes)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
x = self.outc(x)
return x

65
unet_parts.py Normal file
View file

@ -0,0 +1,65 @@
# sub-parts of the U-Net model
import torch
import torch.nn as nn
import torch.nn.functional as F
class double_conv(nn.Module):
def __init__(self, in_ch, out_ch):
super(double_conv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.ReLU(),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.ReLU()
)
def forward(self, x):
x = self.conv(x)
return x
class inconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(inconv, self).__init__()
self.conv = double_conv(in_ch, out_ch)
def forward(self, x):
x = self.conv(x)
return x
class down(nn.Module):
def __init__(self, in_ch, out_ch):
super(down, self).__init__()
self.mpconv = nn.Sequential(
nn.MaxPool2d(2),
double_conv(in_ch, out_ch)
)
def forward(self, x):
x = self.mpconv(x)
return x
class up(nn.Module):
def __init__(self, in_ch, out_ch):
super(up, self).__init__()
self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)
self.conv = double_conv(in_ch, out_ch)
def forward(self, x1, x2):
x1 = self.up(x1)
diffX = x1.size()[2] - x2.size()[2]
diffY = x1.size()[3] - x2.size()[3]
x2 = F.pad(x2, (diffX // 2, int(diffX / 2),
diffY // 2, int(diffY / 2)))
x = torch.cat([x2, x1], dim=1)
x = self.conv(x)
return x
class outconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(outconv, self).__init__()
self.conv = nn.Conv2d(in_ch, out_ch, 1)
def forward(self, x):
x = self.conv(x)
return x

24
utils.py Normal file
View file

@ -0,0 +1,24 @@
import PIL
def split_into_squares(img):
"""Extract a left and a right square from ndarray"""
"""shape : (H, W, C))"""
h = img.shape[0]
w = img.shape[1]
left = img[:, :h]
right = img[:, -h:]
return left, right
def resize_and_crop(pilimg, scale=0.5, final_height=640):
w = pilimg.size[0]
h = pilimg.size[1]
newW = int(w * scale)
newH = int(h * scale)
diff = newH - final_height
img = pilimg.resize((newW, newH))
img = img.crop((0, diff // 2, newW, newH - diff // 2))
return img