mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
Deleted main.py as it is now outdated
Former-commit-id: f57fa46f2d8cd4b0a13be0ee88708fa4dd4e0a88
This commit is contained in:
parent
617f334e06
commit
0da4ad7753
105
main.py
105
main.py
|
@ -1,105 +0,0 @@
|
||||||
#models
|
|
||||||
from unet import UNet
|
|
||||||
from myloss import *
|
|
||||||
import torch
|
|
||||||
from torch.autograd import Variable
|
|
||||||
from torch import optim
|
|
||||||
|
|
||||||
#data manipulation
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
import PIL
|
|
||||||
|
|
||||||
#load files
|
|
||||||
import os
|
|
||||||
|
|
||||||
#data visualization
|
|
||||||
from utils import *
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
#quit after interrupt
|
|
||||||
import sys
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
dir = 'data'
|
|
||||||
ids = []
|
|
||||||
|
|
||||||
for f in os.listdir(dir + '/train'):
|
|
||||||
id = f[:-4]
|
|
||||||
ids.append([id, 0])
|
|
||||||
ids.append([id, 1])
|
|
||||||
|
|
||||||
np.random.shuffle(ids)
|
|
||||||
#%%
|
|
||||||
|
|
||||||
|
|
||||||
net = UNet(3, 1)
|
|
||||||
net.cuda()
|
|
||||||
|
|
||||||
def train(net):
|
|
||||||
optimizer = optim.Adam(net.parameters(), lr=1)
|
|
||||||
criterion = DiceLoss()
|
|
||||||
|
|
||||||
epochs = 5
|
|
||||||
for epoch in range(epochs):
|
|
||||||
print('epoch {}/{}...'.format(epoch+1, epochs))
|
|
||||||
l = 0
|
|
||||||
|
|
||||||
for i, c in enumerate(ids):
|
|
||||||
id = c[0]
|
|
||||||
pos = c[1]
|
|
||||||
im = PIL.Image.open(dir + '/train/' + id + '.jpg')
|
|
||||||
im = resize_and_crop(im)
|
|
||||||
|
|
||||||
ma = PIL.Image.open(dir + '/train_masks/' + id + '_mask.gif')
|
|
||||||
ma = resize_and_crop(ma)
|
|
||||||
|
|
||||||
left, right = split_into_squares(np.array(im))
|
|
||||||
left_m, right_m = split_into_squares(np.array(ma))
|
|
||||||
|
|
||||||
if pos == 0:
|
|
||||||
X = left
|
|
||||||
y = left_m
|
|
||||||
else:
|
|
||||||
X = right
|
|
||||||
y = right_m
|
|
||||||
|
|
||||||
|
|
||||||
X = np.transpose(X, axes=[2, 0, 1])
|
|
||||||
X = torch.FloatTensor(X / 255).unsqueeze(0).cuda()
|
|
||||||
y = Variable(torch.ByteTensor(y)).cuda()
|
|
||||||
|
|
||||||
X = Variable(X).cuda()
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
y_pred = net(X).squeeze(1)
|
|
||||||
|
|
||||||
|
|
||||||
loss = criterion(y_pred, y.unsqueeze(0).float())
|
|
||||||
|
|
||||||
l += loss.data[0]
|
|
||||||
loss.backward()
|
|
||||||
if i%10 == 0:
|
|
||||||
optimizer.step()
|
|
||||||
print('Stepped')
|
|
||||||
|
|
||||||
print('{0:.4f}%\t\t{1:.6f}'.format(i/len(ids)*100, loss.data[0]))
|
|
||||||
|
|
||||||
l = l / len(ids)
|
|
||||||
print('Loss : {}'.format(l))
|
|
||||||
torch.save(net.state_dict(), 'MODEL_EPOCH{}_LOSS{}.pth'.format(epoch+1, l))
|
|
||||||
print('Saved')
|
|
||||||
|
|
||||||
try:
|
|
||||||
net.load_state_dict(torch.load('MODEL_INTERRUPTED.pth'))
|
|
||||||
train(net)
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print('Interrupted')
|
|
||||||
torch.save(net.state_dict(), 'MODEL_INTERRUPTED.pth')
|
|
||||||
try:
|
|
||||||
sys.exit(0)
|
|
||||||
except SystemExit:
|
|
||||||
os._exit(0)
|
|
Loading…
Reference in a new issue