REVA-QCAV/main.py

106 lines
2.2 KiB
Python
Raw Normal View History

#models
from unet import UNet
from myloss import *
import torch
from torch.autograd import Variable
from torch import optim
#data manipulation
import numpy as np
import pandas as pd
import PIL
#load files
import os
2017-08-19 08:59:51 +00:00
#data visualization
from utils import *
import matplotlib.pyplot as plt
2017-08-19 08:59:51 +00:00
#quit after interrupt
import sys
# Data root: images are read from <dir>/train and masks from
# <dir>/train_masks.  NOTE: ``dir`` shadows the builtin of the same name, but
# train() reads it under this name, so the name is kept.
dir = 'data'

# One sample per (image, half) pair: every training image contributes two
# entries, [img_id, 0] and [img_id, 1]; train() uses pos == 0 for the left
# square half and anything else for the right half.
ids = []
# sorted() fixes the listing order (os.listdir order is arbitrary) so the
# shuffle below is reproducible under a fixed numpy seed.
for fname in sorted(os.listdir(dir + '/train')):
    # train() re-opens '<img_id>.jpg', so only .jpg files are usable; any
    # other entry (e.g. '.DS_Store') would otherwise crash training mid-epoch.
    if not fname.endswith('.jpg'):
        continue
    img_id = fname[:-len('.jpg')]
    ids.append([img_id, 0])
    ids.append([img_id, 1])
np.random.shuffle(ids)
#%%
# Build the segmentation network and move its parameters to the GPU in one
# expression (nn.Module.cuda() returns the module itself).  The arguments
# (3, 1) presumably mean 3 input channels (RGB) and 1 output mask channel,
# which matches train() feeding 3-channel images and squeezing dim 1 of the
# output -- confirm against unet.UNet.
net = UNet(3, 1).cuda()
def train(net, epochs=5, lr=1, accumulation_steps=10):
    """Train ``net`` with Adam + DiceLoss on the samples in the global ``ids``.

    Each entry of ``ids`` is ``[img_id, pos]``.  The image
    ``<dir>/train/<img_id>.jpg`` and its mask
    ``<dir>/train_masks/<img_id>_mask.gif`` are resized/cropped and split into
    two squares; ``pos == 0`` selects the left square, anything else the
    right one.  Gradients are accumulated over ``accumulation_steps`` samples
    before each optimizer step (an effective batch of that size).  After each
    epoch the mean loss is printed and the weights are saved to
    ``MODEL_EPOCH<epoch>_LOSS<mean loss>.pth``.

    Args:
        net: the network to train.  Inputs are moved with ``.cuda()``, so
            ``net`` is expected to already live on the GPU.
        epochs: number of passes over ``ids`` (default 5, as before).
        lr: Adam learning rate.  NOTE(review): the historical value 1 is
            unusually high for Adam (common values are around 1e-3); it is
            kept as the default only for backward compatibility -- confirm.
        accumulation_steps: how many samples' gradients are summed before
            each ``optimizer.step()`` (default 10, as before).
    """
    # ``import PIL`` alone does not load the PIL.Image submodule; import it
    # explicitly instead of relying on some other module having done so.
    from PIL import Image

    optimizer = optim.Adam(net.parameters(), lr=lr)
    criterion = DiceLoss()

    n_samples = len(ids)
    if n_samples == 0:
        # Averaging the epoch loss would divide by zero.
        print('No training samples found; nothing to do.')
        return

    for epoch in range(epochs):
        print('epoch {}/{}...'.format(epoch + 1, epochs))
        epoch_loss = 0
        # Start every epoch from clean gradients so nothing left over from a
        # previous (partial) accumulation window leaks into the first step.
        optimizer.zero_grad()

        for i, (img_id, pos) in enumerate(ids):
            im = resize_and_crop(Image.open(dir + '/train/' + img_id + '.jpg'))
            ma = resize_and_crop(
                Image.open(dir + '/train_masks/' + img_id + '_mask.gif'))

            left, right = split_into_squares(np.array(im))
            left_m, right_m = split_into_squares(np.array(ma))

            if pos == 0:
                X, y = left, left_m
            else:
                X, y = right, right_m

            # Move the last (presumably channel) axis first and add a batch
            # dimension; /255 presumably rescales uint8 pixels to [0, 1].
            X = np.transpose(X, axes=[2, 0, 1])
            X = Variable(torch.FloatTensor(X / 255).unsqueeze(0).cuda())
            y = Variable(torch.ByteTensor(y).cuda())

            y_pred = net(X).squeeze(1)
            loss = criterion(y_pred, y.unsqueeze(0).float())

            # NOTE(review): ``loss.data[0]`` is the PyTorch <= 0.3 idiom; on
            # newer versions a 0-dim loss needs ``loss.item()`` instead.
            epoch_loss += loss.data[0]

            # backward() adds into .grad until zero_grad() is called, so the
            # gradients are only cleared *after* the step that consumes them.
            # (Previously zero_grad() ran every sample, discarding 9 of every
            # 10 samples' gradients.)  The last partial window is stepped too.
            loss.backward()
            if (i + 1) % accumulation_steps == 0 or i + 1 == n_samples:
                optimizer.step()
                optimizer.zero_grad()
                print('Stepped')

            # 100.0 * ... keeps the percentage correct under Python 2 as well.
            print('{0:.4f}%\t\t{1:.6f}'.format(100.0 * i / n_samples,
                                                loss.data[0]))

        epoch_loss = epoch_loss / n_samples
        print('Loss : {}'.format(epoch_loss))
        torch.save(net.state_dict(),
                   'MODEL_EPOCH{}_LOSS{}.pth'.format(epoch + 1, epoch_loss))
        print('Saved')
# Resume from the checkpoint written by a previous Ctrl-C (if any), then
# train.  On Ctrl-C, save the current weights so the next run can resume.
try:
    # The checkpoint only exists after an earlier interrupted run; loading it
    # unconditionally made every fresh run die with FileNotFoundError.
    if os.path.isfile('MODEL_INTERRUPTED.pth'):
        net.load_state_dict(torch.load('MODEL_INTERRUPTED.pth'))
        print('Resumed from MODEL_INTERRUPTED.pth')
    train(net)
except KeyboardInterrupt:
    print('Interrupted')
    torch.save(net.state_dict(), 'MODEL_INTERRUPTED.pth')
    # The exit below always ends in os._exit(0): sys.exit() raises
    # SystemExit, which is caught immediately.  os._exit() skips interpreter
    # cleanup, including flushing stdio buffers, so flush first or the
    # message above can be lost when stdout is piped.
    sys.stdout.flush()
    try:
        sys.exit(0)
    except SystemExit:
        os._exit(0)