First prototype
slow, no gpu, no validation
This commit is contained in:
parent
415d600d3a
commit
c8c82204bf
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
|
@ -0,0 +1,3 @@
|
|||
*.pyc
|
||||
data/
|
||||
__pycache__/
|
12
data_vis.py
Normal file
12
data_vis.py
Normal file
|
@ -0,0 +1,12 @@
|
|||
import matplotlib.pyplot as plt
|
||||
|
||||
def plot_img_mask(img, mask):
|
||||
fig = plt.figure()
|
||||
|
||||
ax1 = fig.add_subplot(1, 3, 1)
|
||||
ax1.imshow(img)
|
||||
|
||||
ax2 = fig.add_subplot(1, 3, 2)
|
||||
ax2.imshow(mask)
|
||||
|
||||
plt.show()
|
101
main.py
Normal file
101
main.py
Normal file
|
@ -0,0 +1,101 @@
|
|||
#models
|
||||
from unet_model import UNet
|
||||
from myloss import *
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
from torch import optim
|
||||
|
||||
#data manipulation
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import cv2
|
||||
import PIL
|
||||
|
||||
#load files
|
||||
import os
|
||||
|
||||
#data vis
|
||||
from data_vis import plot_img_mask
|
||||
from utils import *
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
dir = 'data'
|
||||
ids = []
|
||||
|
||||
for f in os.listdir(dir + '/train'):
|
||||
id = f[:-4]
|
||||
ids.append([id, 0])
|
||||
ids.append([id, 1])
|
||||
|
||||
np.random.shuffle(ids)
|
||||
#%%
|
||||
|
||||
|
||||
net = UNet(3, 1)
|
||||
|
||||
optimizer = optim.Adam(net.parameters(), lr=0.001)
|
||||
criterion = DiceLoss()
|
||||
|
||||
dataset = []
|
||||
epochs = 5
|
||||
for epoch in range(epochs):
|
||||
print('epoch {}/{}...'.format(epoch+1, epochs))
|
||||
l = 0
|
||||
|
||||
for i, c in enumerate(ids):
|
||||
id = c[0]
|
||||
pos = c[1]
|
||||
im = PIL.Image.open(dir + '/train/' + id + '.jpg')
|
||||
im = resize_and_crop(im)
|
||||
|
||||
ma = PIL.Image.open(dir + '/train_masks/' + id + '_mask.gif')
|
||||
ma = resize_and_crop(ma)
|
||||
|
||||
left, right = split_into_squares(np.array(im))
|
||||
left_m, right_m = split_into_squares(np.array(ma))
|
||||
|
||||
if pos == 0:
|
||||
X = left
|
||||
y = left_m
|
||||
else:
|
||||
X = right
|
||||
y = right_m
|
||||
|
||||
|
||||
X = np.transpose(X, axes=[2, 0, 1])
|
||||
X = torch.FloatTensor(X / 255).unsqueeze(0)
|
||||
y = Variable(torch.ByteTensor(y))
|
||||
|
||||
X = Variable(X, requires_grad=False)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
y_pred = net(X).squeeze(1)
|
||||
|
||||
|
||||
loss = criterion(y_pred, y.unsqueeze(0).float())
|
||||
|
||||
l += loss.data[0]
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
print('{0:.4f}%.'.format(i/len(ids)*100, end=''))
|
||||
|
||||
print('Loss : {}'.format(l))
|
||||
|
||||
|
||||
#%%
|
||||
|
||||
|
||||
|
||||
|
||||
#net = UNet(3, 2)
|
||||
|
||||
#x = Variable(torch.FloatTensor(np.random.randn(1, 3, 640, 640)))
|
||||
|
||||
#y = net(x)
|
||||
|
||||
|
||||
#plt.imshow(y[0])
|
||||
#plt.show()
|
35
myloss.py
Normal file
35
myloss.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
import torch
|
||||
from torch.nn.modules.loss import _Loss
|
||||
from torch.autograd import Function
|
||||
import torch.nn.functional as F
|
||||
|
||||
class DiceCoeff(Function):
|
||||
|
||||
def forward(ctx, input, target):
|
||||
ctx.save_for_backward(input, target)
|
||||
ctx.inter = torch.dot(input, target) + 0.0001
|
||||
ctx.union = torch.sum(input) + torch.sum(target) + 0.0001
|
||||
|
||||
t = 2*ctx.inter.float()/ctx.union.float()
|
||||
return t
|
||||
|
||||
# This function has only a single output, so it gets only one gradient
|
||||
def backward(ctx, grad_output):
|
||||
|
||||
input, target = ctx.saved_variables
|
||||
grad_input = grad_target = None
|
||||
|
||||
if self.needs_input_grad[0]:
|
||||
grad_input = grad_output * 2 * (target * ctx.union + ctx.inter) \
|
||||
/ ctx.union * ctx.union
|
||||
if self.needs_input_grad[1]:
|
||||
grad_target = None
|
||||
|
||||
return grad_input, grad_target
|
||||
|
||||
def dice_coeff(input, target):
|
||||
return DiceCoeff().forward(input, target)
|
||||
|
||||
class DiceLoss(_Loss):
|
||||
def forward(self, input, target):
|
||||
return 1 - dice_coeff(F.sigmoid(input), target)
|
32
unet_model.py
Normal file
32
unet_model.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from unet_parts import *
|
||||
|
||||
class UNet(nn.Module):
|
||||
def __init__(self, n_channels, n_classes):
|
||||
super(UNet, self).__init__()
|
||||
self.inc = inconv(n_channels, 64)
|
||||
self.down1 = down(64, 128)
|
||||
self.down2 = down(128, 256)
|
||||
self.down3 = down(256, 512)
|
||||
self.down4 = down(512, 1024)
|
||||
self.up1 = up(1024, 512)
|
||||
self.up2 = up(512, 256)
|
||||
self.up3 = up(256, 128)
|
||||
self.up4 = up(128, 64)
|
||||
self.outc = outconv(64, n_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.inc(x)
|
||||
x2 = self.down1(x1)
|
||||
x3 = self.down2(x2)
|
||||
x4 = self.down3(x3)
|
||||
x5 = self.down4(x4)
|
||||
x = self.up1(x5, x4)
|
||||
x = self.up2(x, x3)
|
||||
x = self.up3(x, x2)
|
||||
x = self.up4(x, x1)
|
||||
x = self.outc(x)
|
||||
return x
|
65
unet_parts.py
Normal file
65
unet_parts.py
Normal file
|
@ -0,0 +1,65 @@
|
|||
# sub-parts of the U-Net model
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class double_conv(nn.Module):
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super(double_conv, self).__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(in_ch, out_ch, 3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(out_ch, out_ch, 3, padding=1),
|
||||
nn.ReLU()
|
||||
)
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
class inconv(nn.Module):
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super(inconv, self).__init__()
|
||||
self.conv = double_conv(in_ch, out_ch)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
class down(nn.Module):
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super(down, self).__init__()
|
||||
self.mpconv = nn.Sequential(
|
||||
nn.MaxPool2d(2),
|
||||
double_conv(in_ch, out_ch)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.mpconv(x)
|
||||
return x
|
||||
|
||||
class up(nn.Module):
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super(up, self).__init__()
|
||||
self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)
|
||||
self.conv = double_conv(in_ch, out_ch)
|
||||
|
||||
def forward(self, x1, x2):
|
||||
|
||||
x1 = self.up(x1)
|
||||
diffX = x1.size()[2] - x2.size()[2]
|
||||
diffY = x1.size()[3] - x2.size()[3]
|
||||
x2 = F.pad(x2, (diffX // 2, int(diffX / 2),
|
||||
diffY // 2, int(diffY / 2)))
|
||||
x = torch.cat([x2, x1], dim=1)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
class outconv(nn.Module):
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super(outconv, self).__init__()
|
||||
self.conv = nn.Conv2d(in_ch, out_ch, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
24
utils.py
Normal file
24
utils.py
Normal file
|
@ -0,0 +1,24 @@
|
|||
import PIL
|
||||
|
||||
def split_into_squares(img):
|
||||
"""Extract a left and a right square from ndarray"""
|
||||
"""shape : (H, W, C))"""
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
|
||||
|
||||
left = img[:, :h]
|
||||
right = img[:, -h:]
|
||||
|
||||
return left, right
|
||||
|
||||
def resize_and_crop(pilimg, scale=0.5, final_height=640):
|
||||
w = pilimg.size[0]
|
||||
h = pilimg.size[1]
|
||||
newW = int(w * scale)
|
||||
newH = int(h * scale)
|
||||
diff = newH - final_height
|
||||
|
||||
img = pilimg.resize((newW, newH))
|
||||
img = img.crop((0, diff // 2, newW, newH - diff // 2))
|
||||
return img
|
Loading…
Reference in a new issue