mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 23:12:05 +00:00
First prototype
slow, no gpu, no validation
This commit is contained in:
parent
415d600d3a
commit
c8c82204bf
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
*.pyc
|
||||||
|
data/
|
||||||
|
__pycache__/
|
12
data_vis.py
Normal file
12
data_vis.py
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
def plot_img_mask(img, mask):
|
||||||
|
fig = plt.figure()
|
||||||
|
|
||||||
|
ax1 = fig.add_subplot(1, 3, 1)
|
||||||
|
ax1.imshow(img)
|
||||||
|
|
||||||
|
ax2 = fig.add_subplot(1, 3, 2)
|
||||||
|
ax2.imshow(mask)
|
||||||
|
|
||||||
|
plt.show()
|
101
main.py
Normal file
101
main.py
Normal file
|
@ -0,0 +1,101 @@
|
||||||
|
#models
|
||||||
|
from unet_model import UNet
|
||||||
|
from myloss import *
|
||||||
|
import torch
|
||||||
|
from torch.autograd import Variable
|
||||||
|
from torch import optim
|
||||||
|
|
||||||
|
#data manipulation
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import cv2
|
||||||
|
import PIL
|
||||||
|
|
||||||
|
#load files
|
||||||
|
import os
|
||||||
|
|
||||||
|
#data vis
|
||||||
|
from data_vis import plot_img_mask
|
||||||
|
from utils import *
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
dir = 'data'
|
||||||
|
ids = []
|
||||||
|
|
||||||
|
for f in os.listdir(dir + '/train'):
|
||||||
|
id = f[:-4]
|
||||||
|
ids.append([id, 0])
|
||||||
|
ids.append([id, 1])
|
||||||
|
|
||||||
|
np.random.shuffle(ids)
|
||||||
|
#%%
|
||||||
|
|
||||||
|
|
||||||
|
net = UNet(3, 1)
|
||||||
|
|
||||||
|
optimizer = optim.Adam(net.parameters(), lr=0.001)
|
||||||
|
criterion = DiceLoss()
|
||||||
|
|
||||||
|
dataset = []
|
||||||
|
epochs = 5
|
||||||
|
for epoch in range(epochs):
|
||||||
|
print('epoch {}/{}...'.format(epoch+1, epochs))
|
||||||
|
l = 0
|
||||||
|
|
||||||
|
for i, c in enumerate(ids):
|
||||||
|
id = c[0]
|
||||||
|
pos = c[1]
|
||||||
|
im = PIL.Image.open(dir + '/train/' + id + '.jpg')
|
||||||
|
im = resize_and_crop(im)
|
||||||
|
|
||||||
|
ma = PIL.Image.open(dir + '/train_masks/' + id + '_mask.gif')
|
||||||
|
ma = resize_and_crop(ma)
|
||||||
|
|
||||||
|
left, right = split_into_squares(np.array(im))
|
||||||
|
left_m, right_m = split_into_squares(np.array(ma))
|
||||||
|
|
||||||
|
if pos == 0:
|
||||||
|
X = left
|
||||||
|
y = left_m
|
||||||
|
else:
|
||||||
|
X = right
|
||||||
|
y = right_m
|
||||||
|
|
||||||
|
|
||||||
|
X = np.transpose(X, axes=[2, 0, 1])
|
||||||
|
X = torch.FloatTensor(X / 255).unsqueeze(0)
|
||||||
|
y = Variable(torch.ByteTensor(y))
|
||||||
|
|
||||||
|
X = Variable(X, requires_grad=False)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
y_pred = net(X).squeeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
loss = criterion(y_pred, y.unsqueeze(0).float())
|
||||||
|
|
||||||
|
l += loss.data[0]
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
print('{0:.4f}%.'.format(i/len(ids)*100, end=''))
|
||||||
|
|
||||||
|
print('Loss : {}'.format(l))
|
||||||
|
|
||||||
|
|
||||||
|
#%%
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#net = UNet(3, 2)
|
||||||
|
|
||||||
|
#x = Variable(torch.FloatTensor(np.random.randn(1, 3, 640, 640)))
|
||||||
|
|
||||||
|
#y = net(x)
|
||||||
|
|
||||||
|
|
||||||
|
#plt.imshow(y[0])
|
||||||
|
#plt.show()
|
35
myloss.py
Normal file
35
myloss.py
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
import torch
|
||||||
|
from torch.nn.modules.loss import _Loss
|
||||||
|
from torch.autograd import Function
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class DiceCoeff(Function):
|
||||||
|
|
||||||
|
def forward(ctx, input, target):
|
||||||
|
ctx.save_for_backward(input, target)
|
||||||
|
ctx.inter = torch.dot(input, target) + 0.0001
|
||||||
|
ctx.union = torch.sum(input) + torch.sum(target) + 0.0001
|
||||||
|
|
||||||
|
t = 2*ctx.inter.float()/ctx.union.float()
|
||||||
|
return t
|
||||||
|
|
||||||
|
# This function has only a single output, so it gets only one gradient
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
|
||||||
|
input, target = ctx.saved_variables
|
||||||
|
grad_input = grad_target = None
|
||||||
|
|
||||||
|
if self.needs_input_grad[0]:
|
||||||
|
grad_input = grad_output * 2 * (target * ctx.union + ctx.inter) \
|
||||||
|
/ ctx.union * ctx.union
|
||||||
|
if self.needs_input_grad[1]:
|
||||||
|
grad_target = None
|
||||||
|
|
||||||
|
return grad_input, grad_target
|
||||||
|
|
||||||
|
def dice_coeff(input, target):
|
||||||
|
return DiceCoeff().forward(input, target)
|
||||||
|
|
||||||
|
class DiceLoss(_Loss):
|
||||||
|
def forward(self, input, target):
|
||||||
|
return 1 - dice_coeff(F.sigmoid(input), target)
|
32
unet_model.py
Normal file
32
unet_model.py
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from unet_parts import *
|
||||||
|
|
||||||
|
class UNet(nn.Module):
|
||||||
|
def __init__(self, n_channels, n_classes):
|
||||||
|
super(UNet, self).__init__()
|
||||||
|
self.inc = inconv(n_channels, 64)
|
||||||
|
self.down1 = down(64, 128)
|
||||||
|
self.down2 = down(128, 256)
|
||||||
|
self.down3 = down(256, 512)
|
||||||
|
self.down4 = down(512, 1024)
|
||||||
|
self.up1 = up(1024, 512)
|
||||||
|
self.up2 = up(512, 256)
|
||||||
|
self.up3 = up(256, 128)
|
||||||
|
self.up4 = up(128, 64)
|
||||||
|
self.outc = outconv(64, n_classes)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x1 = self.inc(x)
|
||||||
|
x2 = self.down1(x1)
|
||||||
|
x3 = self.down2(x2)
|
||||||
|
x4 = self.down3(x3)
|
||||||
|
x5 = self.down4(x4)
|
||||||
|
x = self.up1(x5, x4)
|
||||||
|
x = self.up2(x, x3)
|
||||||
|
x = self.up3(x, x2)
|
||||||
|
x = self.up4(x, x1)
|
||||||
|
x = self.outc(x)
|
||||||
|
return x
|
65
unet_parts.py
Normal file
65
unet_parts.py
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
# sub-parts of the U-Net model
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class double_conv(nn.Module):
|
||||||
|
def __init__(self, in_ch, out_ch):
|
||||||
|
super(double_conv, self).__init__()
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
nn.Conv2d(in_ch, out_ch, 3, padding=1),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(out_ch, out_ch, 3, padding=1),
|
||||||
|
nn.ReLU()
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class inconv(nn.Module):
|
||||||
|
def __init__(self, in_ch, out_ch):
|
||||||
|
super(inconv, self).__init__()
|
||||||
|
self.conv = double_conv(in_ch, out_ch)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class down(nn.Module):
|
||||||
|
def __init__(self, in_ch, out_ch):
|
||||||
|
super(down, self).__init__()
|
||||||
|
self.mpconv = nn.Sequential(
|
||||||
|
nn.MaxPool2d(2),
|
||||||
|
double_conv(in_ch, out_ch)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.mpconv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class up(nn.Module):
|
||||||
|
def __init__(self, in_ch, out_ch):
|
||||||
|
super(up, self).__init__()
|
||||||
|
self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)
|
||||||
|
self.conv = double_conv(in_ch, out_ch)
|
||||||
|
|
||||||
|
def forward(self, x1, x2):
|
||||||
|
|
||||||
|
x1 = self.up(x1)
|
||||||
|
diffX = x1.size()[2] - x2.size()[2]
|
||||||
|
diffY = x1.size()[3] - x2.size()[3]
|
||||||
|
x2 = F.pad(x2, (diffX // 2, int(diffX / 2),
|
||||||
|
diffY // 2, int(diffY / 2)))
|
||||||
|
x = torch.cat([x2, x1], dim=1)
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class outconv(nn.Module):
|
||||||
|
def __init__(self, in_ch, out_ch):
|
||||||
|
super(outconv, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(in_ch, out_ch, 1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
24
utils.py
Normal file
24
utils.py
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
import PIL
|
||||||
|
|
||||||
|
def split_into_squares(img):
|
||||||
|
"""Extract a left and a right square from ndarray"""
|
||||||
|
"""shape : (H, W, C))"""
|
||||||
|
h = img.shape[0]
|
||||||
|
w = img.shape[1]
|
||||||
|
|
||||||
|
|
||||||
|
left = img[:, :h]
|
||||||
|
right = img[:, -h:]
|
||||||
|
|
||||||
|
return left, right
|
||||||
|
|
||||||
|
def resize_and_crop(pilimg, scale=0.5, final_height=640):
|
||||||
|
w = pilimg.size[0]
|
||||||
|
h = pilimg.size[1]
|
||||||
|
newW = int(w * scale)
|
||||||
|
newH = int(h * scale)
|
||||||
|
diff = newH - final_height
|
||||||
|
|
||||||
|
img = pilimg.resize((newW, newH))
|
||||||
|
img = img.crop((0, diff // 2, newW, newH - diff // 2))
|
||||||
|
return img
|
Loading…
Reference in a new issue