mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00
Migration to PyTorch 0.4, code cleanup
Former-commit-id: c981801ccc3b74047e94c76e67c4ff1f3097226c
This commit is contained in:
parent
90e988c10f
commit
02e2314149
|
@ -1,17 +1,12 @@
|
||||||
#
|
|
||||||
# myloss.py : implementation of the Dice coeff and the associated loss
|
|
||||||
#
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.autograd import Function, Variable
|
from torch.autograd import Function, Variable
|
||||||
|
|
||||||
|
|
||||||
class DiceCoeff(Function):
|
class DiceCoeff(Function):
|
||||||
"""Dice coeff for individual examples"""
|
"""Dice coeff for individual examples"""
|
||||||
|
|
||||||
def forward(self, input, target):
|
def forward(self, input, target):
|
||||||
self.save_for_backward(input, target)
|
self.save_for_backward(input, target)
|
||||||
self.inter = torch.dot(input, target) + 0.0001
|
self.inter = torch.dot(input.view(-1), target.view(-1)) + 0.0001
|
||||||
self.union = torch.sum(input) + torch.sum(target) + 0.0001
|
self.union = torch.sum(input) + torch.sum(target) + 0.0001
|
||||||
|
|
||||||
t = 2 * self.inter.float() / self.union.float()
|
t = 2 * self.inter.float() / self.union.float()
|
||||||
|
@ -35,9 +30,9 @@ class DiceCoeff(Function):
|
||||||
def dice_coeff(input, target):
|
def dice_coeff(input, target):
|
||||||
"""Dice coeff for batches"""
|
"""Dice coeff for batches"""
|
||||||
if input.is_cuda:
|
if input.is_cuda:
|
||||||
s = Variable(torch.FloatTensor(1).cuda().zero_())
|
s = torch.FloatTensor(1).cuda().zero_()
|
||||||
else:
|
else:
|
||||||
s = Variable(torch.FloatTensor(1).zero_())
|
s = torch.FloatTensor(1).zero_()
|
||||||
|
|
||||||
for i, c in enumerate(zip(input, target)):
|
for i, c in enumerate(zip(input, target)):
|
||||||
s = s + DiceCoeff().forward(c[0], c[1])
|
s = s + DiceCoeff().forward(c[0], c[1])
|
52
eval.py
52
eval.py
|
@ -1,55 +1,25 @@
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.autograd import Variable
|
|
||||||
|
|
||||||
from myloss import dice_coeff
|
from dice_loss import dice_coeff
|
||||||
from utils import dense_crf
|
|
||||||
|
|
||||||
|
|
||||||
def eval_net(net, dataset, gpu=False):
|
def eval_net(net, dataset, gpu=False):
|
||||||
|
"""Evaluation without the densecrf with the dice coefficient"""
|
||||||
tot = 0
|
tot = 0
|
||||||
for i, b in enumerate(dataset):
|
for i, b in enumerate(dataset):
|
||||||
X = b[0]
|
img = b[0]
|
||||||
y = b[1]
|
true_mask = b[1]
|
||||||
|
|
||||||
X = torch.FloatTensor(X).unsqueeze(0)
|
img = torch.from_numpy(img).unsqueeze(0)
|
||||||
y = torch.ByteTensor(y).unsqueeze(0)
|
true_mask = torch.from_numpy(true_mask).unsqueeze(0)
|
||||||
|
|
||||||
if gpu:
|
if gpu:
|
||||||
X = Variable(X, volatile=True).cuda()
|
img = img.cuda()
|
||||||
y = Variable(y, volatile=True).cuda()
|
true_mask = true_mask.cuda()
|
||||||
else:
|
|
||||||
X = Variable(X, volatile=True)
|
|
||||||
y = Variable(y, volatile=True)
|
|
||||||
|
|
||||||
y_pred = net(X)
|
mask_pred = net(img)[0]
|
||||||
|
mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
|
||||||
|
|
||||||
y_pred = (F.sigmoid(y_pred) > 0.6).float()
|
tot += dice_coeff(mask_pred, true_mask).item()
|
||||||
# y_pred = F.sigmoid(y_pred).float()
|
|
||||||
|
|
||||||
dice = dice_coeff(y_pred, y.float()).data[0]
|
|
||||||
tot += dice
|
|
||||||
|
|
||||||
if 0:
|
|
||||||
X = X.data.squeeze(0).cpu().numpy()
|
|
||||||
X = np.transpose(X, axes=[1, 2, 0])
|
|
||||||
y = y.data.squeeze(0).cpu().numpy()
|
|
||||||
y_pred = y_pred.data.squeeze(0).squeeze(0).cpu().numpy()
|
|
||||||
print(y_pred.shape)
|
|
||||||
|
|
||||||
fig = plt.figure()
|
|
||||||
ax1 = fig.add_subplot(1, 4, 1)
|
|
||||||
ax1.imshow(X)
|
|
||||||
ax2 = fig.add_subplot(1, 4, 2)
|
|
||||||
ax2.imshow(y)
|
|
||||||
ax3 = fig.add_subplot(1, 4, 3)
|
|
||||||
ax3.imshow((y_pred > 0.5))
|
|
||||||
|
|
||||||
Q = dense_crf(((X * 255).round()).astype(np.uint8), y_pred)
|
|
||||||
ax4 = fig.add_subplot(1, 4, 4)
|
|
||||||
print(Q)
|
|
||||||
ax4.imshow(Q > 0.5)
|
|
||||||
plt.show()
|
|
||||||
return tot / i
|
return tot / i
|
||||||
|
|
163
predict.py
163
predict.py
|
@ -1,48 +1,64 @@
|
||||||
import argparse
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
import numpy
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.autograd import Variable
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
from unet import UNet
|
from unet import UNet
|
||||||
from utils import *
|
from utils import resize_and_crop, normalize, split_img_into_squares, hwc_to_chw, merge_masks, dense_crf
|
||||||
|
from utils import plot_img_and_mask
|
||||||
|
|
||||||
|
def predict_img(net,
|
||||||
|
full_img,
|
||||||
|
scale_factor=0.5,
|
||||||
|
out_threshold=0.5,
|
||||||
|
use_dense_crf=True,
|
||||||
|
use_gpu=False):
|
||||||
|
|
||||||
|
img_height = full_img.size[1]
|
||||||
|
img_width = full_img.size[0]
|
||||||
|
|
||||||
|
img = resize_and_crop(full_img, scale=scale_factor)
|
||||||
|
img = normalize(img)
|
||||||
|
|
||||||
|
left_square, right_square = split_img_into_squares(img)
|
||||||
|
|
||||||
|
left_square = hwc_to_chw(left_square)
|
||||||
|
right_square = hwc_to_chw(right_square)
|
||||||
|
|
||||||
|
X_left = torch.from_numpy(left_square).unsqueeze(0)
|
||||||
|
X_right = torch.from_numpy(right_square).unsqueeze(0)
|
||||||
|
|
||||||
|
if use_gpu:
|
||||||
|
X_left = X_left.cuda()
|
||||||
|
X_right = X_right.cuda()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output_left = net(X_left)
|
||||||
|
output_right = net(X_right)
|
||||||
|
|
||||||
|
left_probs = F.sigmoid(output_left)
|
||||||
|
right_probs = F.sigmoid(output_right)
|
||||||
|
|
||||||
|
left_probs = F.upsample(left_probs, size=(img_height, img_height))
|
||||||
|
right_probs = F.upsample(right_probs, size=(img_height, img_height))
|
||||||
|
|
||||||
|
left_mask_np = left_probs.squeeze().cpu().numpy()
|
||||||
|
right_mask_np = right_probs.squeeze().cpu().numpy()
|
||||||
|
|
||||||
|
full_mask = merge_masks(left_mask_np, right_mask_np, img_width)
|
||||||
|
|
||||||
|
if use_dense_crf:
|
||||||
|
full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask)
|
||||||
|
|
||||||
|
return full_mask > out_threshold
|
||||||
|
|
||||||
|
|
||||||
def predict_img(net, full_img, gpu=False):
|
|
||||||
img = resize_and_crop(full_img)
|
|
||||||
|
|
||||||
left = get_square(img, 0)
|
def get_args():
|
||||||
right = get_square(img, 1)
|
|
||||||
|
|
||||||
right = normalize(right)
|
|
||||||
left = normalize(left)
|
|
||||||
|
|
||||||
right = np.transpose(right, axes=[2, 0, 1])
|
|
||||||
left = np.transpose(left, axes=[2, 0, 1])
|
|
||||||
|
|
||||||
X_l = torch.FloatTensor(left).unsqueeze(0)
|
|
||||||
X_r = torch.FloatTensor(right).unsqueeze(0)
|
|
||||||
|
|
||||||
if gpu:
|
|
||||||
X_l = Variable(X_l, volatile=True).cuda()
|
|
||||||
X_r = Variable(X_r, volatile=True).cuda()
|
|
||||||
else:
|
|
||||||
X_l = Variable(X_l, volatile=True)
|
|
||||||
X_r = Variable(X_r, volatile=True)
|
|
||||||
|
|
||||||
y_l = F.sigmoid(net(X_l))
|
|
||||||
y_r = F.sigmoid(net(X_r))
|
|
||||||
y_l = F.upsample_bilinear(y_l, scale_factor=2).data[0][0].cpu().numpy()
|
|
||||||
y_r = F.upsample_bilinear(y_r, scale_factor=2).data[0][0].cpu().numpy()
|
|
||||||
|
|
||||||
y = merge_masks(y_l, y_r, full_img.size[0])
|
|
||||||
yy = dense_crf(np.array(full_img).astype(np.uint8), y)
|
|
||||||
|
|
||||||
return yy > 0.5
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--model', '-m', default='MODEL.pth',
|
parser.add_argument('--model', '-m', default='MODEL.pth',
|
||||||
metavar='FILE',
|
metavar='FILE',
|
||||||
|
@ -61,19 +77,22 @@ if __name__ == "__main__":
|
||||||
parser.add_argument('--no-save', '-n', action='store_false',
|
parser.add_argument('--no-save', '-n', action='store_false',
|
||||||
help="Do not save the output masks",
|
help="Do not save the output masks",
|
||||||
default=False)
|
default=False)
|
||||||
|
parser.add_argument('--no-crf', '-r', action='store_false',
|
||||||
|
help="Do not use dense CRF postprocessing",
|
||||||
|
default=False)
|
||||||
|
parser.add_argument('--mask-threshold', '-t', type=float,
|
||||||
|
help="Minimum probability value to consider a mask pixel white",
|
||||||
|
default=0.5)
|
||||||
|
parser.add_argument('--scale', '-s', type=float,
|
||||||
|
help="Scale factor for the input images",
|
||||||
|
default=0.5)
|
||||||
|
|
||||||
args = parser.parse_args()
|
return parser.parse_args()
|
||||||
print("Using model file : {}".format(args.model))
|
|
||||||
net = UNet(3, 1)
|
|
||||||
if not args.cpu:
|
|
||||||
print("Using CUDA version of the net, prepare your GPU !")
|
|
||||||
net.cuda()
|
|
||||||
else:
|
|
||||||
net.cpu()
|
|
||||||
print("Using CPU version of the net, this may be very slow")
|
|
||||||
|
|
||||||
|
def get_output_filenames(args):
|
||||||
in_files = args.input
|
in_files = args.input
|
||||||
out_files = []
|
out_files = []
|
||||||
|
|
||||||
if not args.output:
|
if not args.output:
|
||||||
for f in in_files:
|
for f in in_files:
|
||||||
pathsplit = os.path.splitext(f)
|
pathsplit = os.path.splitext(f)
|
||||||
|
@ -84,32 +103,52 @@ if __name__ == "__main__":
|
||||||
else:
|
else:
|
||||||
out_files = args.output
|
out_files = args.output
|
||||||
|
|
||||||
print("Loading model ...")
|
return out_files
|
||||||
net.load_state_dict(torch.load(args.model))
|
|
||||||
|
def mask_to_image(mask):
|
||||||
|
return Image.fromarray((mask * 255).astype(np.uint8))
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = get_args()
|
||||||
|
in_files = args.input
|
||||||
|
out_files = get_output_filenames(args)
|
||||||
|
|
||||||
|
net = UNet(n_channels=3, n_classes=1)
|
||||||
|
|
||||||
|
print("Loading model {}".format(args.model))
|
||||||
|
|
||||||
|
if not args.cpu:
|
||||||
|
print("Using CUDA version of the net, prepare your GPU !")
|
||||||
|
net.cuda()
|
||||||
|
net.load_state_dict(torch.load(args.model))
|
||||||
|
else:
|
||||||
|
net.cpu()
|
||||||
|
net.load_state_dict(torch.load(args.model, map_location='cpu'))
|
||||||
|
print("Using CPU version of the net, this may be very slow")
|
||||||
|
|
||||||
print("Model loaded !")
|
print("Model loaded !")
|
||||||
|
|
||||||
for i, fn in enumerate(in_files):
|
for i, fn in enumerate(in_files):
|
||||||
print("\nPredicting image {} ...".format(fn))
|
print("\nPredicting image {} ...".format(fn))
|
||||||
|
|
||||||
img = Image.open(fn)
|
img = Image.open(fn)
|
||||||
out = predict_img(net, img, not args.cpu)
|
if img.size[0] < img.size[1]:
|
||||||
|
print("Error: image height larger than the width")
|
||||||
|
|
||||||
|
mask = predict_img(net=net,
|
||||||
|
full_img=img,
|
||||||
|
scale_factor=args.scale,
|
||||||
|
out_threshold=args.mask_threshold,
|
||||||
|
use_dense_crf= not args.no_crf,
|
||||||
|
use_gpu=not args.cpu)
|
||||||
|
|
||||||
if args.viz:
|
if args.viz:
|
||||||
print("Vizualising results for image {}, close to continue ..."
|
print("Visualizing results for image {}, close to continue ...".format(fn))
|
||||||
.format(fn))
|
plot_img_and_mask(img, mask)
|
||||||
|
|
||||||
fig = plt.figure()
|
|
||||||
a = fig.add_subplot(1, 2, 1)
|
|
||||||
a.set_title('Input image')
|
|
||||||
plt.imshow(img)
|
|
||||||
|
|
||||||
b = fig.add_subplot(1, 2, 2)
|
|
||||||
b.set_title('Output mask')
|
|
||||||
plt.imshow(out)
|
|
||||||
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
if not args.no_save:
|
if not args.no_save:
|
||||||
out_fn = out_files[i]
|
out_fn = out_files[i]
|
||||||
result = Image.fromarray((out * 255).astype(numpy.uint8))
|
result = mask_to_image(mask)
|
||||||
result.save(out_files[i])
|
result.save(out_files[i])
|
||||||
|
|
||||||
print("Mask saved to {}".format(out_files[i]))
|
print("Mask saved to {}".format(out_files[i]))
|
||||||
|
|
|
@ -1,10 +1,15 @@
|
||||||
# used to predict all test images and encode results in a csv file
|
import os
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
from predict import *
|
import torch
|
||||||
|
|
||||||
|
from predict import predict_img
|
||||||
|
from utils import rle_encode
|
||||||
from unet import UNet
|
from unet import UNet
|
||||||
|
|
||||||
|
|
||||||
def submit(net, gpu=False):
|
def submit(net, gpu=False):
|
||||||
|
"""Used for Kaggle submission: predicts and encode all test images"""
|
||||||
dir = 'data/test/'
|
dir = 'data/test/'
|
||||||
|
|
||||||
N = len(list(os.listdir(dir)))
|
N = len(list(os.listdir(dir)))
|
||||||
|
|
120
train.py
120
train.py
|
@ -1,20 +1,27 @@
|
||||||
import sys
|
import sys
|
||||||
|
import os
|
||||||
from optparse import OptionParser
|
from optparse import OptionParser
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.backends.cudnn as cudnn
|
import torch.backends.cudnn as cudnn
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import optim
|
from torch import optim
|
||||||
from torch.autograd import Variable
|
|
||||||
|
|
||||||
from eval import eval_net
|
from eval import eval_net
|
||||||
from unet import UNet
|
from unet import UNet
|
||||||
from utils import *
|
from utils import get_ids, split_ids, split_train_val, get_imgs_and_masks, batch
|
||||||
|
|
||||||
|
def train_net(net,
|
||||||
|
epochs=5,
|
||||||
|
batch_size=1,
|
||||||
|
lr=0.1,
|
||||||
|
val_percent=0.05,
|
||||||
|
save_cp=True,
|
||||||
|
gpu=False,
|
||||||
|
img_scale=0.5):
|
||||||
|
|
||||||
def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
|
|
||||||
cp=True, gpu=False):
|
|
||||||
dir_img = 'data/train/'
|
dir_img = 'data/train/'
|
||||||
dir_mask = 'data/train_masks/'
|
dir_mask = 'data/train_masks/'
|
||||||
dir_checkpoint = 'checkpoints/'
|
dir_checkpoint = 'checkpoints/'
|
||||||
|
@ -34,69 +41,66 @@ def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
|
||||||
Checkpoints: {}
|
Checkpoints: {}
|
||||||
CUDA: {}
|
CUDA: {}
|
||||||
'''.format(epochs, batch_size, lr, len(iddataset['train']),
|
'''.format(epochs, batch_size, lr, len(iddataset['train']),
|
||||||
len(iddataset['val']), str(cp), str(gpu)))
|
len(iddataset['val']), str(save_cp), str(gpu)))
|
||||||
|
|
||||||
N_train = len(iddataset['train'])
|
N_train = len(iddataset['train'])
|
||||||
|
|
||||||
optimizer = optim.SGD(net.parameters(),
|
optimizer = optim.SGD(net.parameters(),
|
||||||
lr=lr, momentum=0.9, weight_decay=0.0005)
|
lr=lr,
|
||||||
|
momentum=0.9,
|
||||||
|
weight_decay=0.0005)
|
||||||
|
|
||||||
criterion = nn.BCELoss()
|
criterion = nn.BCELoss()
|
||||||
|
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
|
print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
|
||||||
|
|
||||||
# reset the generators
|
# reset the generators
|
||||||
train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask)
|
train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask, img_scale)
|
||||||
val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask)
|
val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask, img_scale)
|
||||||
|
|
||||||
epoch_loss = 0
|
epoch_loss = 0
|
||||||
|
|
||||||
|
for i, b in enumerate(batch(train, batch_size)):
|
||||||
|
imgs = np.array([i[0] for i in b]).astype(np.float32)
|
||||||
|
true_masks = np.array([i[1] for i in b])
|
||||||
|
|
||||||
|
imgs = torch.from_numpy(imgs)
|
||||||
|
true_masks = torch.from_numpy(true_masks)
|
||||||
|
|
||||||
|
if gpu:
|
||||||
|
imgs = imgs.cuda()
|
||||||
|
true_masks = true_masks.cuda()
|
||||||
|
|
||||||
|
masks_pred = net(imgs)
|
||||||
|
masks_probs = F.sigmoid(masks_pred)
|
||||||
|
masks_probs_flat = masks_probs.view(-1)
|
||||||
|
|
||||||
|
true_masks_flat = true_masks.view(-1)
|
||||||
|
|
||||||
|
loss = criterion(masks_probs_flat, true_masks_flat)
|
||||||
|
epoch_loss += loss.item()
|
||||||
|
|
||||||
|
print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train, loss.item()))
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
print('Epoch finished ! Loss: {}'.format(epoch_loss / i))
|
||||||
|
|
||||||
if 1:
|
if 1:
|
||||||
val_dice = eval_net(net, val, gpu)
|
val_dice = eval_net(net, val, gpu)
|
||||||
print('Validation Dice Coeff: {}'.format(val_dice))
|
print('Validation Dice Coeff: {}'.format(val_dice))
|
||||||
|
|
||||||
for i, b in enumerate(batch(train, batch_size)):
|
if save_cp:
|
||||||
X = np.array([i[0] for i in b])
|
|
||||||
y = np.array([i[1] for i in b])
|
|
||||||
|
|
||||||
X = torch.FloatTensor(X)
|
|
||||||
y = torch.ByteTensor(y)
|
|
||||||
|
|
||||||
if gpu:
|
|
||||||
X = Variable(X).cuda()
|
|
||||||
y = Variable(y).cuda()
|
|
||||||
else:
|
|
||||||
X = Variable(X)
|
|
||||||
y = Variable(y)
|
|
||||||
|
|
||||||
y_pred = net(X)
|
|
||||||
probs = F.sigmoid(y_pred)
|
|
||||||
probs_flat = probs.view(-1)
|
|
||||||
|
|
||||||
y_flat = y.view(-1)
|
|
||||||
|
|
||||||
loss = criterion(probs_flat, y_flat.float())
|
|
||||||
epoch_loss += loss.data[0]
|
|
||||||
|
|
||||||
print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train,
|
|
||||||
loss.data[0]))
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
print('Epoch finished ! Loss: {}'.format(epoch_loss / i))
|
|
||||||
|
|
||||||
if cp:
|
|
||||||
torch.save(net.state_dict(),
|
torch.save(net.state_dict(),
|
||||||
dir_checkpoint + 'CP{}.pth'.format(epoch + 1))
|
dir_checkpoint + 'CP{}.pth'.format(epoch + 1))
|
||||||
|
|
||||||
print('Checkpoint {} saved !'.format(epoch + 1))
|
print('Checkpoint {} saved !'.format(epoch + 1))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
def get_args():
|
||||||
parser = OptionParser()
|
parser = OptionParser()
|
||||||
parser.add_option('-e', '--epochs', dest='epochs', default=5, type='int',
|
parser.add_option('-e', '--epochs', dest='epochs', default=5, type='int',
|
||||||
help='number of epochs')
|
help='number of epochs')
|
||||||
|
@ -108,22 +112,32 @@ if __name__ == '__main__':
|
||||||
default=False, help='use cuda')
|
default=False, help='use cuda')
|
||||||
parser.add_option('-c', '--load', dest='load',
|
parser.add_option('-c', '--load', dest='load',
|
||||||
default=False, help='load file model')
|
default=False, help='load file model')
|
||||||
|
parser.add_option('-s', '--scale', dest='scale', type='float',
|
||||||
|
default=0.5, help='downscaling factor of the images')
|
||||||
|
|
||||||
(options, args) = parser.parse_args()
|
(options, args) = parser.parse_args()
|
||||||
|
return options
|
||||||
|
|
||||||
net = UNet(3, 1)
|
if __name__ == '__main__':
|
||||||
|
args = get_args()
|
||||||
|
|
||||||
if options.load:
|
net = UNet(n_channels=3, n_classes=1)
|
||||||
net.load_state_dict(torch.load(options.load))
|
|
||||||
print('Model loaded from {}'.format(options.load))
|
|
||||||
|
|
||||||
if options.gpu:
|
if args.load:
|
||||||
|
net.load_state_dict(torch.load(args.load))
|
||||||
|
print('Model loaded from {}'.format(args.load))
|
||||||
|
|
||||||
|
if args.gpu:
|
||||||
net.cuda()
|
net.cuda()
|
||||||
cudnn.benchmark = True
|
# cudnn.benchmark = True # faster convolutions, but more memory
|
||||||
|
|
||||||
try:
|
try:
|
||||||
train_net(net, options.epochs, options.batchsize, options.lr,
|
train_net(net=net,
|
||||||
gpu=options.gpu)
|
epochs=args.epochs,
|
||||||
|
batch_size=args.batchsize,
|
||||||
|
lr=args.lr,
|
||||||
|
gpu=args.gpu,
|
||||||
|
img_scale=args.scale)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
torch.save(net.state_dict(), 'INTERRUPTED.pth')
|
torch.save(net.state_dict(), 'INTERRUPTED.pth')
|
||||||
print('Saved interrupt')
|
print('Saved interrupt')
|
||||||
|
|
|
@ -1,14 +1,7 @@
|
||||||
#!/usr/bin/python
|
|
||||||
# full assembly of the sub-parts to form the complete net
|
# full assembly of the sub-parts to form the complete net
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
# python 3 confusing imports :(
|
|
||||||
from .unet_parts import *
|
from .unet_parts import *
|
||||||
|
|
||||||
|
|
||||||
class UNet(nn.Module):
|
class UNet(nn.Module):
|
||||||
def __init__(self, n_channels, n_classes):
|
def __init__(self, n_channels, n_classes):
|
||||||
super(UNet, self).__init__()
|
super(UNet, self).__init__()
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
#!/usr/bin/python
|
|
||||||
|
|
||||||
# sub-parts of the U-Net model
|
# sub-parts of the U-Net model
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -53,9 +51,9 @@ class up(nn.Module):
|
||||||
super(up, self).__init__()
|
super(up, self).__init__()
|
||||||
|
|
||||||
# would be a nice idea if the upsampling could be learned too,
|
# would be a nice idea if the upsampling could be learned too,
|
||||||
# but my machine do not have enough memory to handle all those weights
|
# but my machine do not have enough memory to handle all those weights
|
||||||
if bilinear:
|
if bilinear:
|
||||||
self.up = nn.UpsamplingBilinear2d(scale_factor=2)
|
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
|
||||||
else:
|
else:
|
||||||
self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)
|
self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pydensecrf.densecrf as dcrf
|
import pydensecrf.densecrf as dcrf
|
||||||
|
|
||||||
|
|
||||||
def dense_crf(img, output_probs):
|
def dense_crf(img, output_probs):
|
||||||
h = output_probs.shape[0]
|
h = output_probs.shape[0]
|
||||||
w = output_probs.shape[1]
|
w = output_probs.shape[1]
|
||||||
|
|
|
@ -1,13 +1,12 @@
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
def plot_img_and_mask(img, mask):
|
||||||
def plot_img_mask(img, mask):
|
|
||||||
fig = plt.figure()
|
fig = plt.figure()
|
||||||
|
a = fig.add_subplot(1, 2, 1)
|
||||||
|
a.set_title('Input image')
|
||||||
|
plt.imshow(img)
|
||||||
|
|
||||||
ax1 = fig.add_subplot(1, 3, 1)
|
b = fig.add_subplot(1, 2, 2)
|
||||||
ax1.imshow(img)
|
b.set_title('Output mask')
|
||||||
|
plt.imshow(mask)
|
||||||
ax2 = fig.add_subplot(1, 3, 2)
|
plt.show()
|
||||||
ax2.imshow(mask)
|
|
||||||
|
|
||||||
plt.show()
|
|
|
@ -3,12 +3,11 @@
|
||||||
# cropped images and masks
|
# cropped images and masks
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from .utils import resize_and_crop, get_square, normalize
|
from .utils import resize_and_crop, get_square, normalize, hwc_to_chw
|
||||||
|
|
||||||
|
|
||||||
def get_ids(dir):
|
def get_ids(dir):
|
||||||
|
@ -21,23 +20,22 @@ def split_ids(ids, n=2):
|
||||||
return ((id, i) for i in range(n) for id in ids)
|
return ((id, i) for i in range(n) for id in ids)
|
||||||
|
|
||||||
|
|
||||||
def to_cropped_imgs(ids, dir, suffix):
|
def to_cropped_imgs(ids, dir, suffix, scale):
|
||||||
"""From a list of tuples, returns the correct cropped img"""
|
"""From a list of tuples, returns the correct cropped img"""
|
||||||
for id, pos in ids:
|
for id, pos in ids:
|
||||||
im = resize_and_crop(Image.open(dir + id + suffix))
|
im = resize_and_crop(Image.open(dir + id + suffix), scale=scale)
|
||||||
yield get_square(im, pos)
|
yield get_square(im, pos)
|
||||||
|
|
||||||
|
def get_imgs_and_masks(ids, dir_img, dir_mask, scale):
|
||||||
def get_imgs_and_masks(ids, dir_img, dir_mask):
|
|
||||||
"""Return all the couples (img, mask)"""
|
"""Return all the couples (img, mask)"""
|
||||||
|
|
||||||
imgs = to_cropped_imgs(ids, dir_img, '.jpg')
|
imgs = to_cropped_imgs(ids, dir_img, '.jpg', scale)
|
||||||
|
|
||||||
# need to transform from HWC to CHW
|
# need to transform from HWC to CHW
|
||||||
imgs_switched = map(partial(np.transpose, axes=[2, 0, 1]), imgs)
|
imgs_switched = map(hwc_to_chw, imgs)
|
||||||
imgs_normalized = map(normalize, imgs_switched)
|
imgs_normalized = map(normalize, imgs_switched)
|
||||||
|
|
||||||
masks = to_cropped_imgs(ids, dir_mask, '_mask.gif')
|
masks = to_cropped_imgs(ids, dir_mask, '_mask.gif', scale)
|
||||||
|
|
||||||
return zip(imgs_normalized, masks)
|
return zip(imgs_normalized, masks)
|
||||||
|
|
||||||
|
|
|
@ -1,17 +1,20 @@
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def get_square(img, pos):
|
def get_square(img, pos):
|
||||||
"""Extract a left or a right square from PILimg shape : (H, W, C))"""
|
"""Extract a left or a right square from ndarray shape : (H, W, C))"""
|
||||||
img = np.array(img)
|
|
||||||
h = img.shape[0]
|
h = img.shape[0]
|
||||||
if pos == 0:
|
if pos == 0:
|
||||||
return img[:, :h]
|
return img[:, :h]
|
||||||
else:
|
else:
|
||||||
return img[:, -h:]
|
return img[:, -h:]
|
||||||
|
|
||||||
|
def split_img_into_squares(img):
|
||||||
|
return get_square(img, 0), get_square(img, 1)
|
||||||
|
|
||||||
|
def hwc_to_chw(img):
|
||||||
|
return np.transpose(img, axes=[2, 0, 1])
|
||||||
|
|
||||||
def resize_and_crop(pilimg, scale=0.5, final_height=None):
|
def resize_and_crop(pilimg, scale=0.5, final_height=None):
|
||||||
w = pilimg.size[0]
|
w = pilimg.size[0]
|
||||||
|
@ -26,8 +29,7 @@ def resize_and_crop(pilimg, scale=0.5, final_height=None):
|
||||||
|
|
||||||
img = pilimg.resize((newW, newH))
|
img = pilimg.resize((newW, newH))
|
||||||
img = img.crop((0, diff // 2, newW, newH - diff // 2))
|
img = img.crop((0, diff // 2, newW, newH - diff // 2))
|
||||||
return img
|
return np.array(img, dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
def batch(iterable, batch_size):
|
def batch(iterable, batch_size):
|
||||||
"""Yields lists by batch"""
|
"""Yields lists by batch"""
|
||||||
|
@ -41,7 +43,6 @@ def batch(iterable, batch_size):
|
||||||
if len(b) > 0:
|
if len(b) > 0:
|
||||||
yield b
|
yield b
|
||||||
|
|
||||||
|
|
||||||
def split_train_val(dataset, val_percent=0.05):
|
def split_train_val(dataset, val_percent=0.05):
|
||||||
dataset = list(dataset)
|
dataset = list(dataset)
|
||||||
length = len(dataset)
|
length = len(dataset)
|
||||||
|
@ -53,18 +54,17 @@ def split_train_val(dataset, val_percent=0.05):
|
||||||
def normalize(x):
|
def normalize(x):
|
||||||
return x / 255
|
return x / 255
|
||||||
|
|
||||||
|
|
||||||
def merge_masks(img1, img2, full_w):
|
def merge_masks(img1, img2, full_w):
|
||||||
h = img1.shape[0]
|
h = img1.shape[0]
|
||||||
|
|
||||||
new = np.zeros((h, full_w), np.float32)
|
new = np.zeros((h, full_w), np.float32)
|
||||||
|
|
||||||
new[:, :full_w // 2 + 1] = img1[:, :full_w // 2 + 1]
|
new[:, :full_w // 2 + 1] = img1[:, :full_w // 2 + 1]
|
||||||
new[:, full_w // 2 + 1:] = img2[:, -(full_w // 2 - 1):]
|
new[:, full_w // 2 + 1:] = img2[:, -(full_w // 2 - 1):]
|
||||||
|
|
||||||
return new
|
return new
|
||||||
|
|
||||||
|
|
||||||
|
# credits to https://stackoverflow.com/users/6076729/manuel-lagunas
|
||||||
def rle_encode(mask_image):
|
def rle_encode(mask_image):
|
||||||
pixels = mask_image.flatten()
|
pixels = mask_image.flatten()
|
||||||
# We avoid issues with '1' at the start or end (at the corners of
|
# We avoid issues with '1' at the start or end (at the corners of
|
||||||
|
|
Loading…
Reference in a new issue