# Mirror of https://github.com/Laurent2916/REVA-QCAV.git
# (synced 2024-11-08 14:39:00 +00:00; 107 lines, 2.3 KiB, Python)
#models
|
|
from unet_model import UNet
|
|
from myloss import *
|
|
import torch
|
|
from torch.autograd import Variable
|
|
from torch import optim
|
|
|
|
#data manipulation
|
|
import numpy as np
|
|
import pandas as pd
|
|
import PIL
|
|
|
|
#load files
|
|
import os
|
|
|
|
#data visualization
|
|
from data_vis import plot_img_mask
|
|
from utils import *
|
|
import matplotlib.pyplot as plt
|
|
|
|
#quit after interrupt
|
|
import sys
|
|
|
|
|
|
|
|
# Root folder of the dataset. train() reads this global and expects
# <dir>/train/<id>.jpg and <dir>/train_masks/<id>_mask.gif.
# NOTE: the name shadows the builtin `dir`; it is kept because train()
# refers to it by this name.
dir = 'data'

# Training samples as [image_id, half] pairs. Every image contributes two
# entries; train() uses half == 0 for the left square and anything else for
# the right square.
ids = []

for filename in os.listdir(os.path.join(dir, 'train')):
    stem, ext = os.path.splitext(filename)
    # train() opens '<id>.jpg', so any other directory entry (hidden files,
    # thumbnails, ...) would only crash the training loop later on.
    if ext != '.jpg':
        continue
    ids.append([stem, 0])
    ids.append([stem, 1])

# In-place shuffle so the two halves of one image are not visited back to back.
np.random.shuffle(ids)
#%%

# UNet is built with the positional arguments (3, 1) -- presumably the number
# of input channels and output classes; confirm against unet_model.
net = UNet(3, 1)
# Move the model to the GPU; train() also places its tensors there.
net.cuda()
|
|
|
|
def train(net, epochs=5, lr=1, accumulate=10):
    """Train *net* on the half-image samples listed in the global ``ids``.

    For every ``[image_id, half]`` entry the image and its mask are loaded
    from the global ``dir`` folder, passed through ``resize_and_crop`` and
    ``split_into_squares``, and the selected half is fed to the network.
    The loss is ``DiceLoss`` and the optimizer is Adam. After each epoch the
    mean loss is printed and the weights are saved to
    ``MODEL_EPOCH<n>_LOSS<mean>.pth`` in the working directory.

    Args:
        net: model to optimise; it is called on a CUDA float tensor built
            from one image half with a leading batch dimension of 1.
        epochs: number of passes over ``ids`` (default 5, as before).
        lr: Adam learning rate (default 1, as before).
            NOTE(review): this is far above Adam's usual 1e-3 -- kept only
            to preserve the existing default; confirm it is intended.
        accumulate: number of samples whose gradients are summed before one
            optimizer step (default 10, the original step interval).

    Raises:
        ValueError: if ``accumulate`` is smaller than 1.
    """
    # A bare `import PIL` does not guarantee that the PIL.Image submodule is
    # loaded; import it explicitly instead of relying on another module
    # having done so.
    from PIL import Image

    if accumulate < 1:
        raise ValueError('accumulate must be >= 1')

    optimizer = optim.Adam(net.parameters(), lr=lr)
    criterion = DiceLoss()

    for epoch in range(epochs):
        print('epoch {}/{}...'.format(epoch+1, epochs))
        epoch_loss = 0

        # Gradients are cleared only at the start of an epoch and right after
        # a step. Clearing them on every sample (as before) threw away the
        # backward pass of every sample that was not followed by a step.
        optimizer.zero_grad()
        pending = 0  # samples back-propagated since the last step

        for i, (sample_id, pos) in enumerate(ids):
            im = Image.open(dir + '/train/' + sample_id + '.jpg')
            im = resize_and_crop(im)

            ma = Image.open(dir + '/train_masks/' + sample_id + '_mask.gif')
            ma = resize_and_crop(ma)

            left, right = split_into_squares(np.array(im))
            left_m, right_m = split_into_squares(np.array(ma))

            if pos == 0:
                X, y = left, left_m
            else:
                X, y = right, right_m

            # Move the last image axis to the front, scale to [0, 1] and add
            # a batch dimension of 1 before moving to the GPU (once).
            X = np.transpose(X, axes=[2, 0, 1])
            X = Variable(torch.FloatTensor(X / 255).unsqueeze(0)).cuda()
            y = Variable(torch.ByteTensor(y)).cuda()

            y_pred = net(X).squeeze(1)

            loss = criterion(y_pred, y.unsqueeze(0).float())

            # `loss.data[0]` raises on PyTorch >= 0.4 (0-dim tensor); fall
            # back to it only on releases that predate Tensor.item().
            loss_value = loss.item() if hasattr(loss, 'item') else loss.data[0]
            epoch_loss += loss_value

            # Gradients are summed (not averaged) across the pending samples.
            loss.backward()
            pending += 1

            if pending == accumulate:
                optimizer.step()
                optimizer.zero_grad()
                pending = 0
                print('Stepped')

            print('{0:.4f}%\t\t{1:.6f}'.format(i/len(ids)*100, loss_value))

        # Apply whatever is left when len(ids) is not a multiple of
        # `accumulate`, so the tail of the epoch is not discarded.
        if pending:
            optimizer.step()
            optimizer.zero_grad()
            print('Stepped')

        epoch_loss = epoch_loss / len(ids)
        print('Loss : {}'.format(epoch_loss))
        torch.save(net.state_dict(), 'MODEL_EPOCH{}_LOSS{}.pth'.format(epoch+1, epoch_loss))
        print('Saved')
|
|
|
|
# Script entry: resume from the checkpoint left by an interrupted run (if
# any), train, and on Ctrl-C save the current weights before exiting.
try:
    # Only resume when a previous run actually left a checkpoint; loading
    # unconditionally made every fresh run crash on the missing file.
    if os.path.isfile('MODEL_INTERRUPTED.pth'):
        net.load_state_dict(torch.load('MODEL_INTERRUPTED.pth'))
    train(net)

except KeyboardInterrupt:
    print('Interrupted')
    torch.save(net.state_dict(), 'MODEL_INTERRUPTED.pth')
    # The former `try: sys.exit(0) / except SystemExit: os._exit(0)` always
    # ended in os._exit, which skips stdio flushing and atexit handlers.
    # A plain sys.exit keeps the same exit status with a normal shutdown.
    sys.exit(0)
|