mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 22:42:02 +00:00
c8c82204bf
slow, no gpu, no validation
102 lines
1.8 KiB
Python
102 lines
1.8 KiB
Python
#models
|
|
from unet_model import UNet
|
|
from myloss import *
|
|
import torch
|
|
from torch.autograd import Variable
|
|
from torch import optim
|
|
|
|
#data manipulation
|
|
import numpy as np
|
|
import pandas as pd
|
|
import cv2
|
|
import PIL
|
|
|
|
#load files
|
|
import os
|
|
|
|
#data vis
|
|
from data_vis import plot_img_mask
|
|
from utils import *
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
# Root folder of the dataset. The name shadows the builtin `dir`, but it is
# kept because the training loop further down reads this variable.
dir = 'data'

# Build the list of training samples. Every image yields two entries,
# [image_id, 0] and [image_id, 1]; the training loop uses the second field
# to pick the left (0) or right (1) square of the image.
ids = []
for filename in os.listdir(os.path.join(dir, 'train')):
    stem, ext = os.path.splitext(filename)
    # Only '<id>.jpg' files are opened later, so skip anything else (hidden
    # files, thumbnails, ...) instead of producing an id that cannot be
    # loaded and would crash in the middle of training.
    if ext.lower() != '.jpg':
        continue
    ids.append([stem, 0])
    ids.append([stem, 1])

# Shuffle in place so both halves of one image are not seen back to back.
np.random.shuffle(ids)
|
|
#%%
|
|
|
|
|
|
# Model, optimizer and loss used for training.
# NOTE(review): the arguments are presumably (input channels, output
# channels/classes) -- confirm against unet_model.UNet.
net = UNet(3, 1)

optimizer = optim.Adam(net.parameters(), lr=0.001)
criterion = DiceLoss()

dataset = []  # not used by the loop below; kept in case other cells rely on it
epochs = 5
for epoch in range(epochs):
    print('epoch {}/{}...'.format(epoch+1, epochs))
    # Sum of the per-sample losses seen during this epoch.
    epoch_loss = 0

    # Each entry of `ids` is [image_id, pos]; samples are processed one at a
    # time (batch of 1), reloading the image and its mask from disk each time.
    for i, (image_id, pos) in enumerate(ids):
        im = PIL.Image.open(dir + '/train/' + image_id + '.jpg')
        im = resize_and_crop(im)

        ma = PIL.Image.open(dir + '/train_masks/' + image_id + '_mask.gif')
        ma = resize_and_crop(ma)

        left, right = split_into_squares(np.array(im))
        left_m, right_m = split_into_squares(np.array(ma))

        # pos selects which square of the image/mask pair is the sample.
        if pos == 0:
            X, y = left, left_m
        else:
            X, y = right, right_m

        # Move the last axis to the front, scale by 1/255 and add a leading
        # batch axis. Assumes the image square is laid out as
        # (height, width, channels) with 8-bit values -- TODO confirm against
        # utils.split_into_squares.
        X = np.transpose(X, axes=[2, 0, 1])
        X = torch.FloatTensor(X / 255).unsqueeze(0)
        y = Variable(torch.ByteTensor(y))

        X = Variable(X, requires_grad=False)

        optimizer.zero_grad()

        # Drop axis 1 of the network output so it lines up with the mask,
        # which gets its own leading batch axis below.
        y_pred = net(X).squeeze(1)

        loss = criterion(y_pred, y.unsqueeze(0).float())

        # `.item()` extracts the Python number from the loss tensor; the old
        # `loss.data[0]` indexing raises on 0-dim tensors (PyTorch >= 0.4).
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()

        # Progress through the epoch, one line per sample. The original
        # passed `end=''` to str.format() where it was silently ignored;
        # the no-op argument is dropped and the output is unchanged.
        print('{0:.4f}%.'.format(i/len(ids)*100))

    print('Loss : {}'.format(epoch_loss))
|
|
|
|
|
|
#%%
|
|
|
|
|
|
|
|
|
|
#net = UNet(3, 2)
|
|
|
|
#x = Variable(torch.FloatTensor(np.random.randn(1, 3, 640, 640)))
|
|
|
|
#y = net(x)
|
|
|
|
|
|
#plt.imshow(y[0])
|
|
#plt.show()
|