Global cleanup, better logging and CLI

Former-commit-id: ff1ac0936c118d129bc8a8014958948d3b3883be
This commit is contained in:
milesial 2019-10-24 21:37:21 +02:00
parent 74f825ce06
commit 4c0f0a7a7b
15 changed files with 311 additions and 285 deletions

View file

@ -1 +0,0 @@
408f675eb803bd50727626d588144df3f99e6234

View file

@ -1,12 +1,13 @@
# Pytorch-UNet
# UNet: semantic segmentation with PyTorch
![input and output for a random image in the test dataset](https://framapic.org/OcE8HlU6me61/KNTt8GFQzxDR.png)
Customized implementation of the [U-Net](https://arxiv.org/pdf/1505.04597.pdf) in Pytorch for Kaggle's [Carvana Image Masking Challenge](https://www.kaggle.com/c/carvana-image-masking-challenge) from a high definition image. This was used with only one output class but it can be scaled easily.
Customized implementation of the [U-Net](https://arxiv.org/abs/1505.04597) in PyTorch for Kaggle's [Carvana Image Masking Challenge](https://www.kaggle.com/c/carvana-image-masking-challenge) from high definition images.
This model was trained from scratch with 5000 images (no data augmentation) and scored a [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) of 0.988423 (511 out of 735) on over 100k test images. This score is not quite good but could be improved with more training, data augmentation, fine tuning, playing with CRF post-processing, and applying more weights on the edges of the masks.
This model was trained from scratch with 5000 images (no data augmentation) and scored a [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) of 0.988423 (511 out of 735) on over 100k test images. This score could be improved with more training, data augmentation, fine tuning, playing with CRF post-processing, and applying more weights on the edges of the masks.
The model used for the last submission is stored in the `MODEL.pth` file, if you wish to play with it. The data is available on the [Kaggle website](https://www.kaggle.com/c/carvana-image-masking-challenge/data).
The Carvana data is available on the [Kaggle website](https://www.kaggle.com/c/carvana-image-masking-challenge/data).
## Usage
**Note : Use Python 3**
@ -14,9 +15,6 @@ The model used for the last submission is stored in the `MODEL.pth` file, if you
You can easily test the output masks on your images via the CLI.
To see all options:
`python predict.py -h`
To predict a single image and save it:
`python predict.py -i image.jpg -o output.jpg`
@ -25,15 +23,61 @@ To predict a multiple images and show them without saving them:
`python predict.py -i image1.jpg image2.jpg --viz --no-save`
You can use the cpu-only version with `--cpu`.
```shell script
> python predict.py -h
usage: predict.py [-h] [--model FILE] --input INPUT [INPUT ...]
[--output INPUT [INPUT ...]] [--viz] [--no-save]
[--mask-threshold MASK_THRESHOLD] [--scale SCALE]
Predict masks from input images
optional arguments:
-h, --help show this help message and exit
--model FILE, -m FILE
Specify the file in which the model is stored
(default: MODEL.pth)
--input INPUT [INPUT ...], -i INPUT [INPUT ...]
filenames of input images (default: None)
--output INPUT [INPUT ...], -o INPUT [INPUT ...]
Filenames of ouput images (default: None)
--viz, -v Visualize the images as they are processed (default:
False)
--no-save, -n Do not save the output masks (default: False)
--mask-threshold MASK_THRESHOLD, -t MASK_THRESHOLD
Minimum probability value to consider a mask pixel
white (default: 0.5)
--scale SCALE, -s SCALE
Scale factor for the input images (default: 0.5)
```
You can specify which model file to use with `--model MODEL.pth`.
### Training
`python train.py -h` should get you started. A proper CLI is yet to be added.
## Warning
In order to process the image, it is split into two squares (a left on and a right one), and each square is passed into the net. The two square masks are then merged again to produce the final image. As a consequence, the height of the image must be strictly superior than half the width. Make sure the width is even too.
```shell script
> python train.py -h
usage: train.py [-h] [-e E] [-b [B]] [-l [LR]] [-f LOAD] [-s SCALE] [-v VAL]
Train the UNet on images and target masks
optional arguments:
-h, --help show this help message and exit
-e E, --epochs E Number of epochs (default: 5)
-b [B], --batch-size [B]
Batch size (default: 1)
-l [LR], --learning-rate [LR]
Learning rate (default: 0.1)
-f LOAD, --load LOAD Load model from a .pth file (default: False)
-s SCALE, --scale SCALE
Downscaling factor of the images (default: 0.5)
-v VAL, --validation VAL
Percent of the data that is used as validation (0-100)
(default: 15.0)
```
By default, the `scale` is 0.5, so if you wish to obtain better results (but use more memory), set it to 1.
The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively.
## Dependencies
This package depends on [pydensecrf](https://github.com/lucasb-eyer/pydensecrf), available via `pip install`.
@ -42,5 +86,11 @@ This package depends on [pydensecrf](https://github.com/lucasb-eyer/pydensecrf),
The model has be trained from scratch on a GTX970M 3GB.
Predicting images of 1918*1280 takes 1.5GB of memory.
Training takes approximately 3GB, so if you are a few MB shy of memory, consider turning off all graphical displays.
Training takes much approximately 3GB, so if you are a few MB shy of memory, consider turning off all graphical displays.
This assumes you use bilinear up-sampling, and not transposed convolution in the model.
---
Original paper by Olaf Ronneberger, Philipp Fischer, Thomas Brox: [https://arxiv.org/abs/1505.04597](https://arxiv.org/abs/1505.04597)
![network architecture](https://i.imgur.com/jeDVpqF.png)

0
data/imgs/.keep Normal file
View file

0
data/masks/.keep Normal file
View file

View file

@ -1,5 +1,6 @@
import torch
from torch.autograd import Function, Variable
from torch.autograd import Function
class DiceCoeff(Function):
"""Dice coeff for individual examples"""

20
eval.py
View file

@ -1,26 +1,26 @@
import torch
import torch.nn.functional as F
from tqdm import tqdm
from dice_loss import dice_coeff
def eval_net(net, dataset, gpu=False):
def eval_net(net, dataset, device, n_val):
"""Evaluation without the densecrf with the dice coefficient"""
net.eval()
tot = 0
for i, b in enumerate(dataset):
for i, b in tqdm(enumerate(dataset), total=n_val, desc='Validation round', unit='img'):
img = b[0]
true_mask = b[1]
img = torch.from_numpy(img).unsqueeze(0)
true_mask = torch.from_numpy(true_mask).unsqueeze(0)
if gpu:
img = img.cuda()
true_mask = true_mask.cuda()
img = img.to(device=device)
true_mask = true_mask.to(device=device)
mask_pred = net(img).squeeze(dim=0)
mask_pred = net(img)[0]
mask_pred = (mask_pred > 0.5).float()
tot += dice_coeff(mask_pred, true_mask).item()
return tot / (i + 1)
tot += dice_coeff(mask_pred, true_mask.squeeze(dim=1)).item()
return tot / n_val

View file

@ -1,50 +1,38 @@
import argparse
import logging
import os
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from unet import UNet
from utils import resize_and_crop, normalize, split_img_into_squares, hwc_to_chw, merge_masks, dense_crf
from utils import plot_img_and_mask
from utils import resize_and_crop, normalize, hwc_to_chw, dense_crf
from torchvision import transforms
def predict_img(net,
full_img,
scale_factor=0.5,
device,
scale_factor=1,
out_threshold=0.5,
use_dense_crf=True,
use_gpu=False):
use_dense_crf=False):
net.eval()
img_height = full_img.size[1]
img_width = full_img.size[0]
img = resize_and_crop(full_img, scale=scale_factor)
img = normalize(img)
img = hwc_to_chw(img)
left_square, right_square = split_img_into_squares(img)
X = torch.from_numpy(img).unsqueeze(0)
left_square = hwc_to_chw(left_square)
right_square = hwc_to_chw(right_square)
X_left = torch.from_numpy(left_square).unsqueeze(0)
X_right = torch.from_numpy(right_square).unsqueeze(0)
if use_gpu:
X_left = X_left.cuda()
X_right = X_right.cuda()
X = X.to(device=device)
with torch.no_grad():
output_left = net(X_left)
output_right = net(X_right)
left_probs = output_left.squeeze(0)
right_probs = output_right.squeeze(0)
output = net(X)
probs = output.squeeze(0)
tf = transforms.Compose(
[
@ -54,13 +42,9 @@ def predict_img(net,
]
)
left_probs = tf(left_probs.cpu())
right_probs = tf(right_probs.cpu())
probs = tf(probs.cpu())
left_mask_np = left_probs.squeeze().cpu().numpy()
right_mask_np = right_probs.squeeze().cpu().numpy()
full_mask = merge_masks(left_mask_np, right_mask_np, img_width)
full_mask = probs.squeeze().cpu().numpy()
if use_dense_crf:
full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask)
@ -68,30 +52,23 @@ def predict_img(net,
return full_mask > out_threshold
def get_args():
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(description='Predict masks from input images',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--model', '-m', default='MODEL.pth',
metavar='FILE',
help="Specify the file in which is stored the model"
" (default : 'MODEL.pth')")
help="Specify the file in which the model is stored")
parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
help='filenames of input images', required=True)
parser.add_argument('--output', '-o', metavar='INPUT', nargs='+',
help='filenames of ouput images')
parser.add_argument('--cpu', '-c', action='store_true',
help="Do not use the cuda version of the net",
default=False)
help='Filenames of ouput images')
parser.add_argument('--viz', '-v', action='store_true',
help="Visualize the images as they are processed",
default=False)
parser.add_argument('--no-save', '-n', action='store_true',
help="Do not save the output masks",
default=False)
parser.add_argument('--no-crf', '-r', action='store_true',
help="Do not use dense CRF postprocessing",
default=False)
parser.add_argument('--mask-threshold', '-t', type=float,
help="Minimum probability value to consider a mask pixel white",
default=0.5)
@ -101,6 +78,7 @@ def get_args():
return parser.parse_args()
def get_output_filenames(args):
in_files = args.input
out_files = []
@ -110,16 +88,18 @@ def get_output_filenames(args):
pathsplit = os.path.splitext(f)
out_files.append("{}_OUT{}".format(pathsplit[0], pathsplit[1]))
elif len(in_files) != len(args.output):
print("Error : Input files and output files are not of the same length")
logging.error("Input files and output files are not of the same length")
raise SystemExit()
else:
out_files = args.output
return out_files
def mask_to_image(mask):
return Image.fromarray((mask * 255).astype(np.uint8))
if __name__ == "__main__":
args = get_args()
in_files = args.input
@ -127,40 +107,34 @@ if __name__ == "__main__":
net = UNet(n_channels=3, n_classes=1)
print("Loading model {}".format(args.model))
logging.info("Loading model {}".format(args.model))
if not args.cpu:
print("Using CUDA version of the net, prepare your GPU !")
net.cuda()
net.load_state_dict(torch.load(args.model))
else:
net.cpu()
net.load_state_dict(torch.load(args.model, map_location='cpu'))
print("Using CPU version of the net, this may be very slow")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')
net.to(deviec=device)
net.load_state_dict(torch.load(args.model, map_location=device))
print("Model loaded !")
logging.info("Model loaded !")
for i, fn in enumerate(in_files):
print("\nPredicting image {} ...".format(fn))
logging.info("\nPredicting image {} ...".format(fn))
img = Image.open(fn)
if img.size[0] < img.size[1]:
print("Error: image height larger than the width")
mask = predict_img(net=net,
full_img=img,
scale_factor=args.scale,
out_threshold=args.mask_threshold,
use_dense_crf= not args.no_crf,
use_gpu=not args.cpu)
if args.viz:
print("Visualizing results for image {}, close to continue ...".format(fn))
plot_img_and_mask(img, mask)
use_dense_crf=False,
device=device)
if not args.no_save:
out_fn = out_files[i]
result = mask_to_image(mask)
result.save(out_files[i])
print("Mask saved to {}".format(out_files[i]))
logging.info("Mask saved to {}".format(out_files[i]))
if args.viz:
logging.info("Visualizing results for image {}, close to continue ...".format(fn))
plot_img_and_mask(img, mask)

View file

@ -1,11 +1,13 @@
""" Submit code specific to the kaggle challenge"""
import os
from PIL import Image
import torch
from PIL import Image
from predict import predict_img
from utils import rle_encode
from unet import UNet
from utils import rle_encode
def submit(net, gpu=False):

202
train.py
View file

@ -1,58 +1,52 @@
import sys
import argparse
import logging
import os
from optparse import OptionParser
import numpy as np
import sys
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch import optim
from tqdm import tqdm
from eval import eval_net
from unet import UNet
from utils import get_ids, split_ids, split_train_val, get_imgs_and_masks, batch
from utils import get_ids, split_train_val, get_imgs_and_masks, batch
dir_img = 'data/imgs/'
dir_mask = 'data/masks/'
dir_checkpoint = 'checkpoints/'
def train_net(net,
device,
epochs=5,
batch_size=1,
lr=0.1,
val_percent=0.05,
val_percent=0.15,
save_cp=True,
gpu=False,
img_scale=0.5):
dir_img = 'data/train/'
dir_mask = 'data/train_masks/'
dir_checkpoint = 'checkpoints/'
ids = get_ids(dir_img)
ids = split_ids(ids)
iddataset = split_train_val(ids, val_percent)
print('''
Starting training:
Epochs: {}
Batch size: {}
Learning rate: {}
Training size: {}
Validation size: {}
Checkpoints: {}
CUDA: {}
'''.format(epochs, batch_size, lr, len(iddataset['train']),
len(iddataset['val']), str(save_cp), str(gpu)))
N_train = len(iddataset['train'])
optimizer = optim.SGD(net.parameters(),
lr=lr,
momentum=0.9,
weight_decay=0.0005)
logging.info(f'''Starting training:
Epochs: {epochs}
Batch size: {batch_size}
Learning rate: {lr}
Training size: {len(iddataset["train"])}
Validation size: {len(iddataset["val"])}
Checkpoints: {save_cp}
Device: {device.type}
Images scaling: {img_scale}
''')
n_train = len(iddataset['train'])
n_val = len(iddataset['val'])
optimizer = optim.Adam(net.parameters(), lr=lr)
criterion = nn.BCELoss()
for epoch in range(epochs):
print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
net.train()
# reset the generators
@ -60,87 +54,111 @@ def train_net(net,
val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask, img_scale)
epoch_loss = 0
with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
for i, b in enumerate(batch(train, batch_size)):
imgs = np.array([i[0] for i in b]).astype(np.float32)
true_masks = np.array([i[1] for i in b])
for i, b in enumerate(batch(train, batch_size)):
imgs = np.array([i[0] for i in b]).astype(np.float32)
true_masks = np.array([i[1] for i in b])
imgs = torch.from_numpy(imgs)
true_masks = torch.from_numpy(true_masks)
imgs = torch.from_numpy(imgs)
true_masks = torch.from_numpy(true_masks)
imgs = imgs.to(device=device)
true_masks = true_masks.to(device=device)
if gpu:
imgs = imgs.cuda()
true_masks = true_masks.cuda()
masks_pred = net(imgs)
loss = criterion(masks_pred, true_masks)
epoch_loss += loss.item()
masks_pred = net(imgs)
masks_probs_flat = masks_pred.view(-1)
pbar.set_postfix(**{'loss (batch)': loss.item()})
true_masks_flat = true_masks.view(-1)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss = criterion(masks_probs_flat, true_masks_flat)
epoch_loss += loss.item()
print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train, loss.item()))
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Epoch finished ! Loss: {}'.format(epoch_loss / i))
if 1:
val_dice = eval_net(net, val, gpu)
print('Validation Dice Coeff: {}'.format(val_dice))
pbar.update(batch_size)
if save_cp:
try:
os.mkdir(dir_checkpoint)
logging.info('Created checkpoint directory')
except OSError:
pass
torch.save(net.state_dict(),
dir_checkpoint + 'CP{}.pth'.format(epoch + 1))
print('Checkpoint {} saved !'.format(epoch + 1))
dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
logging.info(f'Checkpoint {epoch + 1} saved !')
val_dice = eval_net(net, val, device, n_val)
logging.info('Validation Dice Coeff: {}'.format(val_dice))
def get_args():
parser = OptionParser()
parser.add_option('-e', '--epochs', dest='epochs', default=5, type='int',
help='number of epochs')
parser.add_option('-b', '--batch-size', dest='batchsize', default=10,
type='int', help='batch size')
parser.add_option('-l', '--learning-rate', dest='lr', default=0.1,
type='float', help='learning rate')
parser.add_option('-g', '--gpu', action='store_true', dest='gpu',
default=False, help='use cuda')
parser.add_option('-c', '--load', dest='load',
default=False, help='load file model')
parser.add_option('-s', '--scale', dest='scale', type='float',
default=0.5, help='downscaling factor of the images')
parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-e', '--epochs', metavar='E', type=int, default=5,
help='Number of epochs', dest='epochs')
parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=1,
help='Batch size', dest='batchsize')
parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.1,
help='Learning rate', dest='lr')
parser.add_argument('-f', '--load', dest='load', type=str, default=False,
help='Load model from a .pth file')
parser.add_argument('-s', '--scale', dest='scale', type=float, default=0.5,
help='Downscaling factor of the images')
parser.add_argument('-v', '--validation', dest='val', type=float, default=15.0,
help='Percent of the data that is used as validation (0-100)')
return parser.parse_args()
def pretrain_checks():
imgs = [f for f in os.listdir(dir_img) if not f.startswith('.')]
masks = [f for f in os.listdir(dir_mask) if not f.startswith('.')]
if len(imgs) != len(masks):
logging.warning(f'The number of images and masks do not match ! '
f'{len(imgs)} images and {len(masks)} masks detected in the data folder.')
(options, args) = parser.parse_args()
return options
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
args = get_args()
pretrain_checks()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')
# Change here to adapt to your data
# n_channels=3 for RGB images
# n_classes is the number of probabilities you want to get per pixel
# - For 1 class and background, use n_classes=1
# - For 2 classes, use n_classes=1
# - For N > 2 classes, use n_classes=N
net = UNet(n_channels=3, n_classes=1)
logging.info(f'Network:\n'
f'\t{net.n_channels} input channels\n'
f'\t{net.n_classes} output channels (classes)\n'
f'\t{"Bilinear" if net.bilinear else "Dilated conv"} upscaling')
if args.load:
net.load_state_dict(torch.load(args.load))
print('Model loaded from {}'.format(args.load))
net.load_state_dict(
torch.load(args.load, map_location=device)
)
logging.info(f'Model loaded from {args.load}')
if args.gpu:
net.cuda()
# cudnn.benchmark = True # faster convolutions, but more memory
net.to(device=device)
# faster convolutions, but more memory
# cudnn.benchmark = True
try:
train_net(net=net,
epochs=args.epochs,
batch_size=args.batchsize,
lr=args.lr,
device=device,
img_scale=args.scale,
val_percent=args.val / 100)
except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth')
logging.info('Saved interrupt')
try:
train_net(net=net,
epochs=args.epochs,
batch_size=args.batchsize,
lr=args.lr,
gpu=args.gpu,
img_scale=args.scale)
except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth')
print('Saved interrupt')
try:
sys.exit(0)
except SystemExit:
os._exit(0)
sys.exit(0)
except SystemExit:
os._exit(0)

View file

@ -1,22 +1,27 @@
# full assembly of the sub-parts to form the complete net
""" Full assembly of the parts to form the complete network """
import torch.nn.functional as F
from .unet_parts import *
class UNet(nn.Module):
def __init__(self, n_channels, n_classes):
def __init__(self, n_channels, n_classes, bilinear=True):
super(UNet, self).__init__()
self.inc = inconv(n_channels, 64)
self.down1 = down(64, 128)
self.down2 = down(128, 256)
self.down3 = down(256, 512)
self.down4 = down(512, 512)
self.up1 = up(1024, 256)
self.up2 = up(512, 128)
self.up3 = up(256, 64)
self.up4 = up(128, 64)
self.outc = outconv(64, n_classes)
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
self.inc = DoubleConv(n_channels, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
self.down4 = Down(512, 512)
self.up1 = Up(1024, 256, bilinear)
self.up2 = Up(512, 128, bilinear)
self.up3 = Up(256, 64, bilinear)
self.up4 = Up(128, 64, bilinear)
self.outc = OutConv(64, n_classes)
def forward(self, x):
x1 = self.inc(x)
@ -29,4 +34,7 @@ class UNet(nn.Module):
x = self.up3(x, x2)
x = self.up4(x, x1)
x = self.outc(x)
return F.sigmoid(x)
if self.n_classes > 1:
return F.softmax(x, dim=1)
else:
return torch.sigmoid(x)

View file

@ -1,88 +1,75 @@
# sub-parts of the U-Net model
""" Parts of the U-Net model """
import torch
import torch.nn as nn
import torch.nn.functional as F
class double_conv(nn.Module):
'''(conv => BN => ReLU) * 2'''
def __init__(self, in_ch, out_ch):
super(double_conv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
x = self.conv(x)
return x
return self.double_conv(x)
class inconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(inconv, self).__init__()
self.conv = double_conv(in_ch, out_ch)
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
def forward(self, x):
x = self.conv(x)
return x
class down(nn.Module):
def __init__(self, in_ch, out_ch):
super(down, self).__init__()
self.mpconv = nn.Sequential(
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
double_conv(in_ch, out_ch)
DoubleConv(in_channels, out_channels)
)
def forward(self, x):
x = self.mpconv(x)
return x
return self.maxpool_conv(x)
class up(nn.Module):
def __init__(self, in_ch, out_ch, bilinear=True):
super(up, self).__init__()
class Up(nn.Module):
"""Upscaling then double conv"""
# would be a nice idea if the upsampling could be learned too,
# but my machine do not have enough memory to handle all those weights
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
else:
self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)
self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
self.conv = double_conv(in_ch, out_ch)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
diffY // 2, diffY - diffY//2))
# for padding issues, see
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
# if you have padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
x = torch.cat([x2, x1], dim=1)
x = self.conv(x)
return x
return self.conv(x)
class outconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(outconv, self).__init__()
self.conv = nn.Conv2d(in_ch, out_ch, 1)
class OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
x = self.conv(x)
return x
return self.conv(x)

View file

@ -1,6 +1,7 @@
import numpy as np
import pydensecrf.densecrf as dcrf
def dense_crf(img, output_probs):
h = output_probs.shape[0]
w = output_probs.shape[1]

View file

@ -1,12 +1,17 @@
import matplotlib.pyplot as plt
def plot_img_and_mask(img, mask):
fig = plt.figure()
a = fig.add_subplot(1, 2, 1)
a.set_title('Input image')
plt.imshow(img)
b = fig.add_subplot(1, 2, 2)
b.set_title('Output mask')
plt.imshow(mask)
def plot_img_and_mask(img, mask):
classes = mask.shape[2] if len(mask.shape) > 2 else 1
fig, ax = plt.subplots(1, classes + 1)
ax[0].set_title('Input image')
ax[0].imshow(img)
if classes > 1:
for i in range(classes):
ax[i+1].set_title(f'Output mask (class {i+1})')
ax[i+1].imshow(mask[:, :, i])
else:
ax[1].set_title(f'Output mask')
ax[1].imshow(mask)
plt.xticks([]), plt.yticks([])
plt.show()

View file

@ -1,34 +1,27 @@
#
# load.py : utils on generators / lists of ids to transform from strings to
# cropped images and masks
""" Utils on generators / lists of ids to transform from strings to cropped images and masks """
import os
import numpy as np
from PIL import Image
from .utils import resize_and_crop, get_square, normalize, hwc_to_chw
from .utils import resize_and_crop, normalize, hwc_to_chw
def get_ids(dir):
"""Returns a list of the ids in the directory"""
return (f[:-4] for f in os.listdir(dir))
def split_ids(ids, n=2):
"""Split each id in n, creating n tuples (id, k) for each id"""
return ((id, i) for id in ids for i in range(n))
return (os.path.splitext(f)[0] for f in os.listdir(dir) if not f.startswith('.'))
def to_cropped_imgs(ids, dir, suffix, scale):
"""From a list of tuples, returns the correct cropped img"""
for id, pos in ids:
for id in ids:
im = resize_and_crop(Image.open(dir + id + suffix), scale=scale)
yield get_square(im, pos)
yield im
def get_imgs_and_masks(ids, dir_img, dir_mask, scale):
"""Return all the couples (img, mask)"""
imgs = to_cropped_imgs(ids, dir_img, '.jpg', scale)
# need to transform from HWC to CHW
@ -36,8 +29,9 @@ def get_imgs_and_masks(ids, dir_img, dir_mask, scale):
imgs_normalized = map(normalize, imgs_switched)
masks = to_cropped_imgs(ids, dir_mask, '_mask.gif', scale)
masks_switched = map(hwc_to_chw, masks)
return zip(imgs_normalized, masks)
return zip(imgs_normalized, masks_switched)
def get_full_img_and_mask(id, dir_img, dir_mask):

View file

@ -1,21 +1,12 @@
import random
import numpy as np
def get_square(img, pos):
"""Extract a left or a right square from ndarray shape : (H, W, C))"""
h = img.shape[0]
if pos == 0:
return img[:, :h]
else:
return img[:, -h:]
def split_img_into_squares(img):
return get_square(img, 0), get_square(img, 1)
def hwc_to_chw(img):
return np.transpose(img, axes=[2, 0, 1])
def resize_and_crop(pilimg, scale=0.5, final_height=None):
w = pilimg.size[0]
h = pilimg.size[1]
@ -29,7 +20,11 @@ def resize_and_crop(pilimg, scale=0.5, final_height=None):
img = pilimg.resize((newW, newH))
img = img.crop((0, diff // 2, newW, newH - diff // 2))
return np.array(img, dtype=np.float32)
ar = np.array(img, dtype=np.float32)
if len(ar.shape) == 2:
# for greyscale images, add a new axis
ar = np.expand_dims(ar, axis=2)
return ar
def batch(iterable, batch_size):
"""Yields lists by batch"""
@ -43,6 +38,7 @@ def batch(iterable, batch_size):
if len(b) > 0:
yield b
def split_train_val(dataset, val_percent=0.05):
dataset = list(dataset)
length = len(dataset)
@ -54,15 +50,6 @@ def split_train_val(dataset, val_percent=0.05):
def normalize(x):
return x / 255
def merge_masks(img1, img2, full_w):
h = img1.shape[0]
new = np.zeros((h, full_w), np.float32)
new[:, :full_w // 2 + 1] = img1[:, :full_w // 2 + 1]
new[:, full_w // 2 + 1:] = img2[:, -(full_w // 2 - 1):]
return new
# credits to https://stackoverflow.com/users/6076729/manuel-lagunas
def rle_encode(mask_image):