Migration to PyTorch 0.4, code cleanup
Former-commit-id: c981801ccc3b74047e94c76e67c4ff1f3097226c
This commit is contained in:
parent
90e988c10f
commit
02e2314149
|
@ -1,17 +1,12 @@
|
|||
#
|
||||
# myloss.py : implementation of the Dice coeff and the associated loss
|
||||
#
|
||||
|
||||
import torch
|
||||
from torch.autograd import Function, Variable
|
||||
|
||||
|
||||
class DiceCoeff(Function):
|
||||
"""Dice coeff for individual examples"""
|
||||
|
||||
def forward(self, input, target):
|
||||
self.save_for_backward(input, target)
|
||||
self.inter = torch.dot(input, target) + 0.0001
|
||||
self.inter = torch.dot(input.view(-1), target.view(-1)) + 0.0001
|
||||
self.union = torch.sum(input) + torch.sum(target) + 0.0001
|
||||
|
||||
t = 2 * self.inter.float() / self.union.float()
|
||||
|
@ -35,9 +30,9 @@ class DiceCoeff(Function):
|
|||
def dice_coeff(input, target):
|
||||
"""Dice coeff for batches"""
|
||||
if input.is_cuda:
|
||||
s = Variable(torch.FloatTensor(1).cuda().zero_())
|
||||
s = torch.FloatTensor(1).cuda().zero_()
|
||||
else:
|
||||
s = Variable(torch.FloatTensor(1).zero_())
|
||||
s = torch.FloatTensor(1).zero_()
|
||||
|
||||
for i, c in enumerate(zip(input, target)):
|
||||
s = s + DiceCoeff().forward(c[0], c[1])
|
52
eval.py
52
eval.py
|
@ -1,55 +1,25 @@
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
|
||||
from myloss import dice_coeff
|
||||
from utils import dense_crf
|
||||
from dice_loss import dice_coeff
|
||||
|
||||
|
||||
def eval_net(net, dataset, gpu=False):
|
||||
"""Evaluation without the densecrf with the dice coefficient"""
|
||||
tot = 0
|
||||
for i, b in enumerate(dataset):
|
||||
X = b[0]
|
||||
y = b[1]
|
||||
img = b[0]
|
||||
true_mask = b[1]
|
||||
|
||||
X = torch.FloatTensor(X).unsqueeze(0)
|
||||
y = torch.ByteTensor(y).unsqueeze(0)
|
||||
img = torch.from_numpy(img).unsqueeze(0)
|
||||
true_mask = torch.from_numpy(true_mask).unsqueeze(0)
|
||||
|
||||
if gpu:
|
||||
X = Variable(X, volatile=True).cuda()
|
||||
y = Variable(y, volatile=True).cuda()
|
||||
else:
|
||||
X = Variable(X, volatile=True)
|
||||
y = Variable(y, volatile=True)
|
||||
img = img.cuda()
|
||||
true_mask = true_mask.cuda()
|
||||
|
||||
y_pred = net(X)
|
||||
mask_pred = net(img)[0]
|
||||
mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
|
||||
|
||||
y_pred = (F.sigmoid(y_pred) > 0.6).float()
|
||||
# y_pred = F.sigmoid(y_pred).float()
|
||||
|
||||
dice = dice_coeff(y_pred, y.float()).data[0]
|
||||
tot += dice
|
||||
|
||||
if 0:
|
||||
X = X.data.squeeze(0).cpu().numpy()
|
||||
X = np.transpose(X, axes=[1, 2, 0])
|
||||
y = y.data.squeeze(0).cpu().numpy()
|
||||
y_pred = y_pred.data.squeeze(0).squeeze(0).cpu().numpy()
|
||||
print(y_pred.shape)
|
||||
|
||||
fig = plt.figure()
|
||||
ax1 = fig.add_subplot(1, 4, 1)
|
||||
ax1.imshow(X)
|
||||
ax2 = fig.add_subplot(1, 4, 2)
|
||||
ax2.imshow(y)
|
||||
ax3 = fig.add_subplot(1, 4, 3)
|
||||
ax3.imshow((y_pred > 0.5))
|
||||
|
||||
Q = dense_crf(((X * 255).round()).astype(np.uint8), y_pred)
|
||||
ax4 = fig.add_subplot(1, 4, 4)
|
||||
print(Q)
|
||||
ax4.imshow(Q > 0.5)
|
||||
plt.show()
|
||||
tot += dice_coeff(mask_pred, true_mask).item()
|
||||
return tot / i
|
||||
|
|
163
predict.py
163
predict.py
|
@ -1,48 +1,64 @@
|
|||
import argparse
|
||||
import os
|
||||
|
||||
import numpy
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from unet import UNet
|
||||
from utils import *
|
||||
from utils import resize_and_crop, normalize, split_img_into_squares, hwc_to_chw, merge_masks, dense_crf
|
||||
from utils import plot_img_and_mask
|
||||
|
||||
def predict_img(net,
|
||||
full_img,
|
||||
scale_factor=0.5,
|
||||
out_threshold=0.5,
|
||||
use_dense_crf=True,
|
||||
use_gpu=False):
|
||||
|
||||
img_height = full_img.size[1]
|
||||
img_width = full_img.size[0]
|
||||
|
||||
img = resize_and_crop(full_img, scale=scale_factor)
|
||||
img = normalize(img)
|
||||
|
||||
left_square, right_square = split_img_into_squares(img)
|
||||
|
||||
left_square = hwc_to_chw(left_square)
|
||||
right_square = hwc_to_chw(right_square)
|
||||
|
||||
X_left = torch.from_numpy(left_square).unsqueeze(0)
|
||||
X_right = torch.from_numpy(right_square).unsqueeze(0)
|
||||
|
||||
if use_gpu:
|
||||
X_left = X_left.cuda()
|
||||
X_right = X_right.cuda()
|
||||
|
||||
with torch.no_grad():
|
||||
output_left = net(X_left)
|
||||
output_right = net(X_right)
|
||||
|
||||
left_probs = F.sigmoid(output_left)
|
||||
right_probs = F.sigmoid(output_right)
|
||||
|
||||
left_probs = F.upsample(left_probs, size=(img_height, img_height))
|
||||
right_probs = F.upsample(right_probs, size=(img_height, img_height))
|
||||
|
||||
left_mask_np = left_probs.squeeze().cpu().numpy()
|
||||
right_mask_np = right_probs.squeeze().cpu().numpy()
|
||||
|
||||
full_mask = merge_masks(left_mask_np, right_mask_np, img_width)
|
||||
|
||||
if use_dense_crf:
|
||||
full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask)
|
||||
|
||||
return full_mask > out_threshold
|
||||
|
||||
|
||||
def predict_img(net, full_img, gpu=False):
|
||||
img = resize_and_crop(full_img)
|
||||
|
||||
left = get_square(img, 0)
|
||||
right = get_square(img, 1)
|
||||
|
||||
right = normalize(right)
|
||||
left = normalize(left)
|
||||
|
||||
right = np.transpose(right, axes=[2, 0, 1])
|
||||
left = np.transpose(left, axes=[2, 0, 1])
|
||||
|
||||
X_l = torch.FloatTensor(left).unsqueeze(0)
|
||||
X_r = torch.FloatTensor(right).unsqueeze(0)
|
||||
|
||||
if gpu:
|
||||
X_l = Variable(X_l, volatile=True).cuda()
|
||||
X_r = Variable(X_r, volatile=True).cuda()
|
||||
else:
|
||||
X_l = Variable(X_l, volatile=True)
|
||||
X_r = Variable(X_r, volatile=True)
|
||||
|
||||
y_l = F.sigmoid(net(X_l))
|
||||
y_r = F.sigmoid(net(X_r))
|
||||
y_l = F.upsample_bilinear(y_l, scale_factor=2).data[0][0].cpu().numpy()
|
||||
y_r = F.upsample_bilinear(y_r, scale_factor=2).data[0][0].cpu().numpy()
|
||||
|
||||
y = merge_masks(y_l, y_r, full_img.size[0])
|
||||
yy = dense_crf(np.array(full_img).astype(np.uint8), y)
|
||||
|
||||
return yy > 0.5
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model', '-m', default='MODEL.pth',
|
||||
metavar='FILE',
|
||||
|
@ -61,19 +77,22 @@ if __name__ == "__main__":
|
|||
parser.add_argument('--no-save', '-n', action='store_false',
|
||||
help="Do not save the output masks",
|
||||
default=False)
|
||||
parser.add_argument('--no-crf', '-r', action='store_false',
|
||||
help="Do not use dense CRF postprocessing",
|
||||
default=False)
|
||||
parser.add_argument('--mask-threshold', '-t', type=float,
|
||||
help="Minimum probability value to consider a mask pixel white",
|
||||
default=0.5)
|
||||
parser.add_argument('--scale', '-s', type=float,
|
||||
help="Scale factor for the input images",
|
||||
default=0.5)
|
||||
|
||||
args = parser.parse_args()
|
||||
print("Using model file : {}".format(args.model))
|
||||
net = UNet(3, 1)
|
||||
if not args.cpu:
|
||||
print("Using CUDA version of the net, prepare your GPU !")
|
||||
net.cuda()
|
||||
else:
|
||||
net.cpu()
|
||||
print("Using CPU version of the net, this may be very slow")
|
||||
return parser.parse_args()
|
||||
|
||||
def get_output_filenames(args):
|
||||
in_files = args.input
|
||||
out_files = []
|
||||
|
||||
if not args.output:
|
||||
for f in in_files:
|
||||
pathsplit = os.path.splitext(f)
|
||||
|
@ -84,32 +103,52 @@ if __name__ == "__main__":
|
|||
else:
|
||||
out_files = args.output
|
||||
|
||||
print("Loading model ...")
|
||||
net.load_state_dict(torch.load(args.model))
|
||||
return out_files
|
||||
|
||||
def mask_to_image(mask):
|
||||
return Image.fromarray((mask * 255).astype(np.uint8))
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
in_files = args.input
|
||||
out_files = get_output_filenames(args)
|
||||
|
||||
net = UNet(n_channels=3, n_classes=1)
|
||||
|
||||
print("Loading model {}".format(args.model))
|
||||
|
||||
if not args.cpu:
|
||||
print("Using CUDA version of the net, prepare your GPU !")
|
||||
net.cuda()
|
||||
net.load_state_dict(torch.load(args.model))
|
||||
else:
|
||||
net.cpu()
|
||||
net.load_state_dict(torch.load(args.model, map_location='cpu'))
|
||||
print("Using CPU version of the net, this may be very slow")
|
||||
|
||||
print("Model loaded !")
|
||||
|
||||
for i, fn in enumerate(in_files):
|
||||
print("\nPredicting image {} ...".format(fn))
|
||||
|
||||
img = Image.open(fn)
|
||||
out = predict_img(net, img, not args.cpu)
|
||||
if img.size[0] < img.size[1]:
|
||||
print("Error: image height larger than the width")
|
||||
|
||||
mask = predict_img(net=net,
|
||||
full_img=img,
|
||||
scale_factor=args.scale,
|
||||
out_threshold=args.mask_threshold,
|
||||
use_dense_crf= not args.no_crf,
|
||||
use_gpu=not args.cpu)
|
||||
|
||||
if args.viz:
|
||||
print("Vizualising results for image {}, close to continue ..."
|
||||
.format(fn))
|
||||
|
||||
fig = plt.figure()
|
||||
a = fig.add_subplot(1, 2, 1)
|
||||
a.set_title('Input image')
|
||||
plt.imshow(img)
|
||||
|
||||
b = fig.add_subplot(1, 2, 2)
|
||||
b.set_title('Output mask')
|
||||
plt.imshow(out)
|
||||
|
||||
plt.show()
|
||||
print("Visualizing results for image {}, close to continue ...".format(fn))
|
||||
plot_img_and_mask(img, mask)
|
||||
|
||||
if not args.no_save:
|
||||
out_fn = out_files[i]
|
||||
result = Image.fromarray((out * 255).astype(numpy.uint8))
|
||||
result = mask_to_image(mask)
|
||||
result.save(out_files[i])
|
||||
|
||||
print("Mask saved to {}".format(out_files[i]))
|
||||
|
|
|
@ -1,10 +1,15 @@
|
|||
# used to predict all test images and encode results in a csv file
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
from predict import *
|
||||
import torch
|
||||
|
||||
from predict import predict_img
|
||||
from utils import rle_encode
|
||||
from unet import UNet
|
||||
|
||||
|
||||
def submit(net, gpu=False):
|
||||
"""Used for Kaggle submission: predicts and encode all test images"""
|
||||
dir = 'data/test/'
|
||||
|
||||
N = len(list(os.listdir(dir)))
|
||||
|
|
120
train.py
120
train.py
|
@ -1,20 +1,27 @@
|
|||
import sys
|
||||
import os
|
||||
from optparse import OptionParser
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import optim
|
||||
from torch.autograd import Variable
|
||||
|
||||
from eval import eval_net
|
||||
from unet import UNet
|
||||
from utils import *
|
||||
from utils import get_ids, split_ids, split_train_val, get_imgs_and_masks, batch
|
||||
|
||||
def train_net(net,
|
||||
epochs=5,
|
||||
batch_size=1,
|
||||
lr=0.1,
|
||||
val_percent=0.05,
|
||||
save_cp=True,
|
||||
gpu=False,
|
||||
img_scale=0.5):
|
||||
|
||||
def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
|
||||
cp=True, gpu=False):
|
||||
dir_img = 'data/train/'
|
||||
dir_mask = 'data/train_masks/'
|
||||
dir_checkpoint = 'checkpoints/'
|
||||
|
@ -34,69 +41,66 @@ def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
|
|||
Checkpoints: {}
|
||||
CUDA: {}
|
||||
'''.format(epochs, batch_size, lr, len(iddataset['train']),
|
||||
len(iddataset['val']), str(cp), str(gpu)))
|
||||
len(iddataset['val']), str(save_cp), str(gpu)))
|
||||
|
||||
N_train = len(iddataset['train'])
|
||||
|
||||
optimizer = optim.SGD(net.parameters(),
|
||||
lr=lr, momentum=0.9, weight_decay=0.0005)
|
||||
lr=lr,
|
||||
momentum=0.9,
|
||||
weight_decay=0.0005)
|
||||
|
||||
criterion = nn.BCELoss()
|
||||
|
||||
for epoch in range(epochs):
|
||||
print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
|
||||
|
||||
# reset the generators
|
||||
train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask)
|
||||
val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask)
|
||||
train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask, img_scale)
|
||||
val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask, img_scale)
|
||||
|
||||
epoch_loss = 0
|
||||
|
||||
for i, b in enumerate(batch(train, batch_size)):
|
||||
imgs = np.array([i[0] for i in b]).astype(np.float32)
|
||||
true_masks = np.array([i[1] for i in b])
|
||||
|
||||
imgs = torch.from_numpy(imgs)
|
||||
true_masks = torch.from_numpy(true_masks)
|
||||
|
||||
if gpu:
|
||||
imgs = imgs.cuda()
|
||||
true_masks = true_masks.cuda()
|
||||
|
||||
masks_pred = net(imgs)
|
||||
masks_probs = F.sigmoid(masks_pred)
|
||||
masks_probs_flat = masks_probs.view(-1)
|
||||
|
||||
true_masks_flat = true_masks.view(-1)
|
||||
|
||||
loss = criterion(masks_probs_flat, true_masks_flat)
|
||||
epoch_loss += loss.item()
|
||||
|
||||
print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train, loss.item()))
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
print('Epoch finished ! Loss: {}'.format(epoch_loss / i))
|
||||
|
||||
if 1:
|
||||
val_dice = eval_net(net, val, gpu)
|
||||
print('Validation Dice Coeff: {}'.format(val_dice))
|
||||
|
||||
for i, b in enumerate(batch(train, batch_size)):
|
||||
X = np.array([i[0] for i in b])
|
||||
y = np.array([i[1] for i in b])
|
||||
|
||||
X = torch.FloatTensor(X)
|
||||
y = torch.ByteTensor(y)
|
||||
|
||||
if gpu:
|
||||
X = Variable(X).cuda()
|
||||
y = Variable(y).cuda()
|
||||
else:
|
||||
X = Variable(X)
|
||||
y = Variable(y)
|
||||
|
||||
y_pred = net(X)
|
||||
probs = F.sigmoid(y_pred)
|
||||
probs_flat = probs.view(-1)
|
||||
|
||||
y_flat = y.view(-1)
|
||||
|
||||
loss = criterion(probs_flat, y_flat.float())
|
||||
epoch_loss += loss.data[0]
|
||||
|
||||
print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train,
|
||||
loss.data[0]))
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
loss.backward()
|
||||
|
||||
optimizer.step()
|
||||
|
||||
print('Epoch finished ! Loss: {}'.format(epoch_loss / i))
|
||||
|
||||
if cp:
|
||||
if save_cp:
|
||||
torch.save(net.state_dict(),
|
||||
dir_checkpoint + 'CP{}.pth'.format(epoch + 1))
|
||||
|
||||
print('Checkpoint {} saved !'.format(epoch + 1))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
def get_args():
|
||||
parser = OptionParser()
|
||||
parser.add_option('-e', '--epochs', dest='epochs', default=5, type='int',
|
||||
help='number of epochs')
|
||||
|
@ -108,22 +112,32 @@ if __name__ == '__main__':
|
|||
default=False, help='use cuda')
|
||||
parser.add_option('-c', '--load', dest='load',
|
||||
default=False, help='load file model')
|
||||
parser.add_option('-s', '--scale', dest='scale', type='float',
|
||||
default=0.5, help='downscaling factor of the images')
|
||||
|
||||
(options, args) = parser.parse_args()
|
||||
return options
|
||||
|
||||
net = UNet(3, 1)
|
||||
if __name__ == '__main__':
|
||||
args = get_args()
|
||||
|
||||
if options.load:
|
||||
net.load_state_dict(torch.load(options.load))
|
||||
print('Model loaded from {}'.format(options.load))
|
||||
net = UNet(n_channels=3, n_classes=1)
|
||||
|
||||
if options.gpu:
|
||||
if args.load:
|
||||
net.load_state_dict(torch.load(args.load))
|
||||
print('Model loaded from {}'.format(args.load))
|
||||
|
||||
if args.gpu:
|
||||
net.cuda()
|
||||
cudnn.benchmark = True
|
||||
# cudnn.benchmark = True # faster convolutions, but more memory
|
||||
|
||||
try:
|
||||
train_net(net, options.epochs, options.batchsize, options.lr,
|
||||
gpu=options.gpu)
|
||||
train_net(net=net,
|
||||
epochs=args.epochs,
|
||||
batch_size=args.batchsize,
|
||||
lr=args.lr,
|
||||
gpu=args.gpu,
|
||||
img_scale=args.scale)
|
||||
except KeyboardInterrupt:
|
||||
torch.save(net.state_dict(), 'INTERRUPTED.pth')
|
||||
print('Saved interrupt')
|
||||
|
|
|
@ -1,14 +1,7 @@
|
|||
#!/usr/bin/python
|
||||
# full assembly of the sub-parts to form the complete net
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# python 3 confusing imports :(
|
||||
from .unet_parts import *
|
||||
|
||||
|
||||
class UNet(nn.Module):
|
||||
def __init__(self, n_channels, n_classes):
|
||||
super(UNet, self).__init__()
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
# sub-parts of the U-Net model
|
||||
|
||||
import torch
|
||||
|
@ -53,9 +51,9 @@ class up(nn.Module):
|
|||
super(up, self).__init__()
|
||||
|
||||
# would be a nice idea if the upsampling could be learned too,
|
||||
# but my machine do not have enough memory to handle all those weights
|
||||
# but my machine do not have enough memory to handle all those weights
|
||||
if bilinear:
|
||||
self.up = nn.UpsamplingBilinear2d(scale_factor=2)
|
||||
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
|
||||
else:
|
||||
self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import numpy as np
|
||||
import pydensecrf.densecrf as dcrf
|
||||
|
||||
|
||||
def dense_crf(img, output_probs):
|
||||
h = output_probs.shape[0]
|
||||
w = output_probs.shape[1]
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def plot_img_mask(img, mask):
|
||||
def plot_img_and_mask(img, mask):
|
||||
fig = plt.figure()
|
||||
a = fig.add_subplot(1, 2, 1)
|
||||
a.set_title('Input image')
|
||||
plt.imshow(img)
|
||||
|
||||
ax1 = fig.add_subplot(1, 3, 1)
|
||||
ax1.imshow(img)
|
||||
|
||||
ax2 = fig.add_subplot(1, 3, 2)
|
||||
ax2.imshow(mask)
|
||||
|
||||
plt.show()
|
||||
b = fig.add_subplot(1, 2, 2)
|
||||
b.set_title('Output mask')
|
||||
plt.imshow(mask)
|
||||
plt.show()
|
|
@ -3,12 +3,11 @@
|
|||
# cropped images and masks
|
||||
|
||||
import os
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from .utils import resize_and_crop, get_square, normalize
|
||||
from .utils import resize_and_crop, get_square, normalize, hwc_to_chw
|
||||
|
||||
|
||||
def get_ids(dir):
|
||||
|
@ -21,23 +20,22 @@ def split_ids(ids, n=2):
|
|||
return ((id, i) for i in range(n) for id in ids)
|
||||
|
||||
|
||||
def to_cropped_imgs(ids, dir, suffix):
|
||||
def to_cropped_imgs(ids, dir, suffix, scale):
|
||||
"""From a list of tuples, returns the correct cropped img"""
|
||||
for id, pos in ids:
|
||||
im = resize_and_crop(Image.open(dir + id + suffix))
|
||||
im = resize_and_crop(Image.open(dir + id + suffix), scale=scale)
|
||||
yield get_square(im, pos)
|
||||
|
||||
|
||||
def get_imgs_and_masks(ids, dir_img, dir_mask):
|
||||
def get_imgs_and_masks(ids, dir_img, dir_mask, scale):
|
||||
"""Return all the couples (img, mask)"""
|
||||
|
||||
imgs = to_cropped_imgs(ids, dir_img, '.jpg')
|
||||
imgs = to_cropped_imgs(ids, dir_img, '.jpg', scale)
|
||||
|
||||
# need to transform from HWC to CHW
|
||||
imgs_switched = map(partial(np.transpose, axes=[2, 0, 1]), imgs)
|
||||
imgs_switched = map(hwc_to_chw, imgs)
|
||||
imgs_normalized = map(normalize, imgs_switched)
|
||||
|
||||
masks = to_cropped_imgs(ids, dir_mask, '_mask.gif')
|
||||
masks = to_cropped_imgs(ids, dir_mask, '_mask.gif', scale)
|
||||
|
||||
return zip(imgs_normalized, masks)
|
||||
|
||||
|
|
|
@ -1,17 +1,20 @@
|
|||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_square(img, pos):
|
||||
"""Extract a left or a right square from PILimg shape : (H, W, C))"""
|
||||
img = np.array(img)
|
||||
"""Extract a left or a right square from ndarray shape : (H, W, C))"""
|
||||
h = img.shape[0]
|
||||
if pos == 0:
|
||||
return img[:, :h]
|
||||
else:
|
||||
return img[:, -h:]
|
||||
|
||||
def split_img_into_squares(img):
|
||||
return get_square(img, 0), get_square(img, 1)
|
||||
|
||||
def hwc_to_chw(img):
|
||||
return np.transpose(img, axes=[2, 0, 1])
|
||||
|
||||
def resize_and_crop(pilimg, scale=0.5, final_height=None):
|
||||
w = pilimg.size[0]
|
||||
|
@ -26,8 +29,7 @@ def resize_and_crop(pilimg, scale=0.5, final_height=None):
|
|||
|
||||
img = pilimg.resize((newW, newH))
|
||||
img = img.crop((0, diff // 2, newW, newH - diff // 2))
|
||||
return img
|
||||
|
||||
return np.array(img, dtype=np.float32)
|
||||
|
||||
def batch(iterable, batch_size):
|
||||
"""Yields lists by batch"""
|
||||
|
@ -41,7 +43,6 @@ def batch(iterable, batch_size):
|
|||
if len(b) > 0:
|
||||
yield b
|
||||
|
||||
|
||||
def split_train_val(dataset, val_percent=0.05):
|
||||
dataset = list(dataset)
|
||||
length = len(dataset)
|
||||
|
@ -53,18 +54,17 @@ def split_train_val(dataset, val_percent=0.05):
|
|||
def normalize(x):
|
||||
return x / 255
|
||||
|
||||
|
||||
def merge_masks(img1, img2, full_w):
|
||||
h = img1.shape[0]
|
||||
|
||||
new = np.zeros((h, full_w), np.float32)
|
||||
|
||||
new[:, :full_w // 2 + 1] = img1[:, :full_w // 2 + 1]
|
||||
new[:, full_w // 2 + 1:] = img2[:, -(full_w // 2 - 1):]
|
||||
|
||||
return new
|
||||
|
||||
|
||||
# credits to https://stackoverflow.com/users/6076729/manuel-lagunas
|
||||
def rle_encode(mask_image):
|
||||
pixels = mask_image.flatten()
|
||||
# We avoid issues with '1' at the start or end (at the corners of
|
||||
|
|
Loading…
Reference in a new issue