From 4c0f0a7a7b8e200772d812b8731a999069bea5e0 Mon Sep 17 00:00:00 2001 From: milesial Date: Thu, 24 Oct 2019 21:37:21 +0200 Subject: [PATCH] Global cleanup, better logging and CLI Former-commit-id: ff1ac0936c118d129bc8a8014958948d3b3883be --- MODEL.pth.REMOVED.git-id | 1 - README.md | 74 +++++++++++--- data/imgs/.keep | 0 data/masks/.keep | 0 dice_loss.py | 3 +- eval.py | 20 ++-- predict.py | 96 +++++++------------ submit.py | 6 +- train.py | 202 +++++++++++++++++++++------------------ unet/unet_model.py | 34 ++++--- unet/unet_parts.py | 85 +++++++--------- utils/crf.py | 1 + utils/data_vis.py | 23 +++-- utils/load.py | 22 ++--- utils/utils.py | 29 ++---- 15 files changed, 311 insertions(+), 285 deletions(-) delete mode 100644 MODEL.pth.REMOVED.git-id create mode 100644 data/imgs/.keep create mode 100644 data/masks/.keep diff --git a/MODEL.pth.REMOVED.git-id b/MODEL.pth.REMOVED.git-id deleted file mode 100644 index 8f9fb71..0000000 --- a/MODEL.pth.REMOVED.git-id +++ /dev/null @@ -1 +0,0 @@ -408f675eb803bd50727626d588144df3f99e6234 \ No newline at end of file diff --git a/README.md b/README.md index cc8334c..5c82c44 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,13 @@ -# Pytorch-UNet +# UNet: semantic segmentation with PyTorch + ![input and output for a random image in the test dataset](https://framapic.org/OcE8HlU6me61/KNTt8GFQzxDR.png) -Customized implementation of the [U-Net](https://arxiv.org/pdf/1505.04597.pdf) in Pytorch for Kaggle's [Carvana Image Masking Challenge](https://www.kaggle.com/c/carvana-image-masking-challenge) from a high definition image. This was used with only one output class but it can be scaled easily. +Customized implementation of the [U-Net](https://arxiv.org/abs/1505.04597) in PyTorch for Kaggle's [Carvana Image Masking Challenge](https://www.kaggle.com/c/carvana-image-masking-challenge) from high definition images. -This model was trained from scratch with 5000 images (no data augmentation) and scored a [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) of 0.988423 (511 out of 735) on over 100k test images. This score is not quite good but could be improved with more training, data augmentation, fine tuning, playing with CRF post-processing, and applying more weights on the edges of the masks. +This model was trained from scratch with 5000 images (no data augmentation) and scored a [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) of 0.988423 (511 out of 735) on over 100k test images. This score could be improved with more training, data augmentation, fine tuning, playing with CRF post-processing, and applying more weights on the edges of the masks. -The model used for the last submission is stored in the `MODEL.pth` file, if you wish to play with it. The data is available on the [Kaggle website](https://www.kaggle.com/c/carvana-image-masking-challenge/data). +The Carvana data is available on the [Kaggle website](https://www.kaggle.com/c/carvana-image-masking-challenge/data). ## Usage **Note : Use Python 3** @@ -14,9 +15,6 @@ The model used for the last submission is stored in the `MODEL.pth` file, if you You can easily test the output masks on your images via the CLI. 
-To see all options:
-`python predict.py -h`
-
 To predict a single image and save it:
 
 `python predict.py -i image.jpg -o output.jpg`
 
@@ -25,15 +23,61 @@ To predict multiple images and show them without saving them:
 
 `python predict.py -i image1.jpg image2.jpg --viz --no-save`
 
-You can use the cpu-only version with `--cpu`.
+```shell script
+> python predict.py -h
+usage: predict.py [-h] [--model FILE] --input INPUT [INPUT ...]
+                  [--output INPUT [INPUT ...]] [--viz] [--no-save]
+                  [--mask-threshold MASK_THRESHOLD] [--scale SCALE]
+
+Predict masks from input images
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --model FILE, -m FILE
+                        Specify the file in which the model is stored
+                        (default: MODEL.pth)
+  --input INPUT [INPUT ...], -i INPUT [INPUT ...]
+                        filenames of input images (default: None)
+  --output INPUT [INPUT ...], -o INPUT [INPUT ...]
+                        Filenames of output images (default: None)
+  --viz, -v             Visualize the images as they are processed (default:
+                        False)
+  --no-save, -n         Do not save the output masks (default: False)
+  --mask-threshold MASK_THRESHOLD, -t MASK_THRESHOLD
+                        Minimum probability value to consider a mask pixel
+                        white (default: 0.5)
+  --scale SCALE, -s SCALE
+                        Scale factor for the input images (default: 0.5)
+```
 
 You can specify which model file to use with `--model MODEL.pth`.
 
 ### Training
 
-`python train.py -h` should get you started. A proper CLI is yet to be added.
-## Warning
-In order to process the image, it is split into two squares (a left on and a right one), and each square is passed into the net. The two square masks are then merged again to produce the final image. As a consequence, the height of the image must be strictly superior than half the width. Make sure the width is even too.
+```shell script
+> python train.py -h
+usage: train.py [-h] [-e E] [-b [B]] [-l [LR]] [-f LOAD] [-s SCALE] [-v VAL]
+
+Train the UNet on images and target masks
+
+optional arguments:
+  -h, --help            show this help message and exit
+  -e E, --epochs E      Number of epochs (default: 5)
+  -b [B], --batch-size [B]
+                        Batch size (default: 1)
+  -l [LR], --learning-rate [LR]
+                        Learning rate (default: 0.1)
+  -f LOAD, --load LOAD  Load model from a .pth file (default: False)
+  -s SCALE, --scale SCALE
+                        Downscaling factor of the images (default: 0.5)
+  -v VAL, --validation VAL
+                        Percent of the data that is used as validation (0-100)
+                        (default: 15.0)
+```
+
+By default, the `scale` is 0.5, so if you wish to obtain better results (but use more memory), set it to 1.
+
+The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively.
+
 ## Dependencies
 This package depends on [pydensecrf](https://github.com/lucasb-eyer/pydensecrf), available via `pip install`.
 
 ## Notes on memory
 
 The model has been trained from scratch on a GTX970M 3GB.
 Predicting images of 1918*1280 takes 1.5GB of memory.
 Training takes approximately 3GB, so if you are a few MB shy of memory, consider turning off all graphical displays.
+This assumes you use bilinear up-sampling, and not transposed convolution in the model.
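To make that trade-off concrete, here is a minimal sketch (assuming the `unet` package from this repository is importable) of how the `bilinear` flag on the `UNet` constructor selects the up-sampling mode:

```python
from unet import UNet

# bilinear=True (the default) up-samples with nn.Upsample, which has no
# learnable parameters and therefore keeps memory usage lower
net = UNet(n_channels=3, n_classes=1, bilinear=True)

# bilinear=False switches the Up blocks to nn.ConvTranspose2d, a learned
# up-sampling that adds weights and memory
net_transposed = UNet(n_channels=3, n_classes=1, bilinear=False)
```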
+
+---
+
+Original paper by Olaf Ronneberger, Philipp Fischer, Thomas Brox: [https://arxiv.org/abs/1505.04597](https://arxiv.org/abs/1505.04597)
+
+![network architecture](https://i.imgur.com/jeDVpqF.png)
diff --git a/data/imgs/.keep b/data/imgs/.keep
new file mode 100644
index 0000000..e69de29
diff --git a/data/masks/.keep b/data/masks/.keep
new file mode 100644
index 0000000..e69de29
diff --git a/dice_loss.py b/dice_loss.py
index 71edf6a..fe86611 100644
--- a/dice_loss.py
+++ b/dice_loss.py
@@ -1,5 +1,6 @@
 import torch
-from torch.autograd import Function, Variable
+from torch.autograd import Function
+
 
 class DiceCoeff(Function):
     """Dice coeff for individual examples"""
diff --git a/eval.py b/eval.py
index c727510..2944bea 100644
--- a/eval.py
+++ b/eval.py
@@ -1,26 +1,26 @@
 import torch
-import torch.nn.functional as F
+from tqdm import tqdm
 
 from dice_loss import dice_coeff
 
 
-def eval_net(net, dataset, gpu=False):
+def eval_net(net, dataset, device, n_val):
     """Evaluation without the densecrf with the dice coefficient"""
     net.eval()
     tot = 0
-    for i, b in enumerate(dataset):
+
+    for i, b in tqdm(enumerate(dataset), total=n_val, desc='Validation round', unit='img'):
         img = b[0]
         true_mask = b[1]
 
         img = torch.from_numpy(img).unsqueeze(0)
         true_mask = torch.from_numpy(true_mask).unsqueeze(0)
 
-        if gpu:
-            img = img.cuda()
-            true_mask = true_mask.cuda()
+        img = img.to(device=device)
+        true_mask = true_mask.to(device=device)
+
+        mask_pred = net(img).squeeze(dim=0)
 
-        mask_pred = net(img)[0]
         mask_pred = (mask_pred > 0.5).float()
-
-        tot += dice_coeff(mask_pred, true_mask).item()
-    return tot / (i + 1)
+        tot += dice_coeff(mask_pred, true_mask.squeeze(dim=1)).item()
+    return tot / n_val
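For a quick sanity check of what `eval_net` averages, the `dice_coeff` helper from `dice_loss.py` can be exercised on plain tensors. A minimal sketch (the shapes are arbitrary, and it assumes `dice_coeff` implements the usual definition, 2·|A∩B| / (|A|+|B|)):

```python
import torch

from dice_loss import dice_coeff

# two dummy binary masks of shape (1, H*W)
pred = torch.tensor([[1., 1., 0., 0.]])
target = torch.tensor([[1., 0., 0., 0.]])

# one overlapping pixel, 2 predicted + 1 true -> 2*1 / (2+1) ≈ 0.67
print(dice_coeff(pred, target).item())
```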
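The refactored `predict_img` can also be driven programmatically, without the CLI; a minimal sketch, assuming a trained `MODEL.pth` checkpoint and an `image.jpg` input (both placeholder names):

```python
import torch
from PIL import Image

from predict import predict_img, mask_to_image
from unet import UNet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(n_channels=3, n_classes=1)
net.to(device=device)
net.load_state_dict(torch.load('MODEL.pth', map_location=device))

# returns a boolean HxW array, thresholded at 0.5 by default
mask = predict_img(net=net, full_img=Image.open('image.jpg'),
                   device=device, scale_factor=0.5)
mask_to_image(mask).save('output.jpg')
```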
diff --git a/submit.py b/submit.py
index 12f26cf..49acd88 100644
--- a/submit.py
+++ b/submit.py
@@ -1,11 +1,13 @@
+""" Submit code specific to the kaggle challenge"""
+
 import os
 
-from PIL import Image
 import torch
+from PIL import Image
 
 from predict import predict_img
-from utils import rle_encode
 from unet import UNet
+from utils import rle_encode
 
 
 def submit(net, gpu=False):
diff --git a/train.py b/train.py
index cfdeb3b..379a683 100644
--- a/train.py
+++ b/train.py
@@ -1,58 +1,52 @@
-import sys
+import argparse
+import logging
 import os
-from optparse import OptionParser
-import numpy as np
+import sys
 
+import numpy as np
 import torch
-import torch.backends.cudnn as cudnn
 import torch.nn as nn
 from torch import optim
+from tqdm import tqdm
 
 from eval import eval_net
 from unet import UNet
-from utils import get_ids, split_ids, split_train_val, get_imgs_and_masks, batch
+from utils import get_ids, split_train_val, get_imgs_and_masks, batch
+
+dir_img = 'data/imgs/'
+dir_mask = 'data/masks/'
+dir_checkpoint = 'checkpoints/'
+
 
 def train_net(net,
+              device,
               epochs=5,
               batch_size=1,
               lr=0.1,
-              val_percent=0.05,
+              val_percent=0.15,
               save_cp=True,
-              gpu=False,
               img_scale=0.5):
-
-    dir_img = 'data/train/'
-    dir_mask = 'data/train_masks/'
-    dir_checkpoint = 'checkpoints/'
-
     ids = get_ids(dir_img)
-    ids = split_ids(ids)
 
     iddataset = split_train_val(ids, val_percent)
 
-    print('''
-    Starting training:
-        Epochs: {}
-        Batch size: {}
-        Learning rate: {}
-        Training size: {}
-        Validation size: {}
-        Checkpoints: {}
-        CUDA: {}
-    '''.format(epochs, batch_size, lr, len(iddataset['train']),
-               len(iddataset['val']), str(save_cp), str(gpu)))
-
-    N_train = len(iddataset['train'])
-
-    optimizer = optim.SGD(net.parameters(),
-                          lr=lr,
-                          momentum=0.9,
-                          weight_decay=0.0005)
+    logging.info(f'''Starting training:
+        Epochs: {epochs}
+        Batch size: {batch_size}
+        Learning rate: {lr}
+        Training size: {len(iddataset["train"])}
+        Validation size: {len(iddataset["val"])}
+        Checkpoints: {save_cp}
+        Device: {device.type}
+        Images scaling: {img_scale}
+    ''')
 
+    n_train = len(iddataset['train'])
+    n_val = len(iddataset['val'])
+    optimizer = optim.Adam(net.parameters(), lr=lr)
     criterion = nn.BCELoss()
 
     for epoch in range(epochs):
-        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
         net.train()
 
         # reset the generators
@@ -60,87 +54,111 @@ def train_net(net,
         val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask, img_scale)
 
         epoch_loss = 0
+        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
+            for i, b in enumerate(batch(train, batch_size)):
+                imgs = np.array([i[0] for i in b]).astype(np.float32)
+                true_masks = np.array([i[1] for i in b])
 
-        for i, b in enumerate(batch(train, batch_size)):
-            imgs = np.array([i[0] for i in b]).astype(np.float32)
-            true_masks = np.array([i[1] for i in b])
+                imgs = torch.from_numpy(imgs)
+                true_masks = torch.from_numpy(true_masks)
 
-            imgs = torch.from_numpy(imgs)
-            true_masks = torch.from_numpy(true_masks)
+                imgs = imgs.to(device=device)
+                true_masks = true_masks.to(device=device)
 
-            if gpu:
-                imgs = imgs.cuda()
-                true_masks = true_masks.cuda()
+                masks_pred = net(imgs)
+                loss = criterion(masks_pred, true_masks)
+                epoch_loss += loss.item()
 
-            masks_pred = net(imgs)
-            masks_probs_flat = masks_pred.view(-1)
+                pbar.set_postfix(**{'loss (batch)': loss.item()})
 
-            true_masks_flat = true_masks.view(-1)
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
 
-            loss = criterion(masks_probs_flat, true_masks_flat)
-            epoch_loss += loss.item()
-
-            print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train, loss.item()))
-
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-
-        print('Epoch finished ! Loss: {}'.format(epoch_loss / i))
-
-        if 1:
-            val_dice = eval_net(net, val, gpu)
-            print('Validation Dice Coeff: {}'.format(val_dice))
+                pbar.update(batch_size)
 
         if save_cp:
+            try:
+                os.mkdir(dir_checkpoint)
+                logging.info('Created checkpoint directory')
+            except OSError:
+                pass
             torch.save(net.state_dict(),
-                       dir_checkpoint + 'CP{}.pth'.format(epoch + 1))
-            print('Checkpoint {} saved !'.format(epoch + 1))
+                       dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
+            logging.info(f'Checkpoint {epoch + 1} saved !')
 
+        val_dice = eval_net(net, val, device, n_val)
+        logging.info('Validation Dice Coeff: {}'.format(val_dice))
 
 
 def get_args():
-    parser = OptionParser()
-    parser.add_option('-e', '--epochs', dest='epochs', default=5, type='int',
-                      help='number of epochs')
-    parser.add_option('-b', '--batch-size', dest='batchsize', default=10,
-                      type='int', help='batch size')
-    parser.add_option('-l', '--learning-rate', dest='lr', default=0.1,
-                      type='float', help='learning rate')
-    parser.add_option('-g', '--gpu', action='store_true', dest='gpu',
-                      default=False, help='use cuda')
-    parser.add_option('-c', '--load', dest='load',
-                      default=False, help='load file model')
-    parser.add_option('-s', '--scale', dest='scale', type='float',
-                      default=0.5, help='downscaling factor of the images')
+    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
+                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('-e', '--epochs', metavar='E', type=int, default=5,
+                        help='Number of epochs', dest='epochs')
+    parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=1,
+                        help='Batch size', dest='batchsize')
+    parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.1,
+                        help='Learning rate', dest='lr')
+    parser.add_argument('-f', '--load', dest='load', type=str, default=False,
+                        help='Load model from a .pth file')
+    parser.add_argument('-s', '--scale', dest='scale', type=float, default=0.5,
+                        help='Downscaling factor of the images')
+    parser.add_argument('-v', '--validation', dest='val', type=float, default=15.0,
+                        help='Percent of the data that is used as validation (0-100)')
+
+    return parser.parse_args()
+
+
+def pretrain_checks():
+    imgs = [f for f in os.listdir(dir_img) if not f.startswith('.')]
+    masks = [f for f in os.listdir(dir_mask) if not f.startswith('.')]
+    if len(imgs) != len(masks):
+        logging.warning(f'The number of images and masks do not match ! '
+                        f'{len(imgs)} images and {len(masks)} masks detected in the data folder.')
 
-    (options, args) = parser.parse_args()
-    return options
 
 
 if __name__ == '__main__':
+    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
     args = get_args()
+    pretrain_checks()
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    logging.info(f'Using device {device}')
 
+    # Change here to adapt to your data
+    # n_channels=3 for RGB images
+    # n_classes is the number of probabilities you want to get per pixel
+    #   - For 1 class and background, use n_classes=1
+    #   - For 2 classes, use n_classes=1
+    #   - For N > 2 classes, use n_classes=N
     net = UNet(n_channels=3, n_classes=1)
+    logging.info(f'Network:\n'
+                 f'\t{net.n_channels} input channels\n'
+                 f'\t{net.n_classes} output channels (classes)\n'
+                 f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')
 
     if args.load:
-        net.load_state_dict(torch.load(args.load))
-        print('Model loaded from {}'.format(args.load))
+        net.load_state_dict(
+            torch.load(args.load, map_location=device)
+        )
+        logging.info(f'Model loaded from {args.load}')
 
-    if args.gpu:
-        net.cuda()
-        # cudnn.benchmark = True # faster convolutions, but more memory
+    net.to(device=device)
+    # faster convolutions, but more memory
+    # cudnn.benchmark = True
 
     try:
-        train_net(net=net,
-                  epochs=args.epochs,
-                  batch_size=args.batchsize,
-                  lr=args.lr,
-                  gpu=args.gpu,
-                  img_scale=args.scale)
+        train_net(net=net,
+                  epochs=args.epochs,
+                  batch_size=args.batchsize,
+                  lr=args.lr,
+                  device=device,
+                  img_scale=args.scale,
+                  val_percent=args.val / 100)
     except KeyboardInterrupt:
         torch.save(net.state_dict(), 'INTERRUPTED.pth')
-        print('Saved interrupt')
+        logging.info('Saved interrupt')
         try:
             sys.exit(0)
         except SystemExit:
             os._exit(0)
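For a quick smoke test of the new training loop outside the CLI, a minimal driver mirroring the `__main__` block (a sketch; it assumes images in `data/imgs/` and masks in `data/masks/` as described in the README):

```python
import torch

from train import train_net
from unet import UNet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(n_channels=3, n_classes=1)
net.to(device=device)

# one cheap epoch at half resolution, with 15% of the data held out
train_net(net=net, device=device, epochs=1, batch_size=1,
          lr=0.1, img_scale=0.5, val_percent=0.15)
```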
diff --git a/unet/unet_model.py b/unet/unet_model.py
index 5990649..466222a 100644
--- a/unet/unet_model.py
+++ b/unet/unet_model.py
@@ -1,22 +1,27 @@
-# full assembly of the sub-parts to form the complete net
+""" Full assembly of the parts to form the complete network """
 
 import torch.nn.functional as F
 
 from .unet_parts import *
 
+
 class UNet(nn.Module):
-    def __init__(self, n_channels, n_classes):
+    def __init__(self, n_channels, n_classes, bilinear=True):
         super(UNet, self).__init__()
-        self.inc = inconv(n_channels, 64)
-        self.down1 = down(64, 128)
-        self.down2 = down(128, 256)
-        self.down3 = down(256, 512)
-        self.down4 = down(512, 512)
-        self.up1 = up(1024, 256)
-        self.up2 = up(512, 128)
-        self.up3 = up(256, 64)
-        self.up4 = up(128, 64)
-        self.outc = outconv(64, n_classes)
+        self.n_channels = n_channels
+        self.n_classes = n_classes
+        self.bilinear = bilinear
+
+        self.inc = DoubleConv(n_channels, 64)
+        self.down1 = Down(64, 128)
+        self.down2 = Down(128, 256)
+        self.down3 = Down(256, 512)
+        self.down4 = Down(512, 512)
+        self.up1 = Up(1024, 256, bilinear)
+        self.up2 = Up(512, 128, bilinear)
+        self.up3 = Up(256, 64, bilinear)
+        self.up4 = Up(128, 64, bilinear)
+        self.outc = OutConv(64, n_classes)
 
     def forward(self, x):
         x1 = self.inc(x)
@@ -29,4 +34,7 @@ class UNet(nn.Module):
         x = self.up3(x, x2)
         x = self.up4(x, x1)
         x = self.outc(x)
-        return F.sigmoid(x)
+        if self.n_classes > 1:
+            return F.softmax(x, dim=1)
+        else:
+            return torch.sigmoid(x)
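A quick shape check of the assembled model (the input size here is arbitrary; since every convolution is padded, the mask comes back at the input resolution):

```python
import torch

from unet import UNet

net = UNet(n_channels=3, n_classes=1)
x = torch.randn(1, 3, 256, 256)  # dummy RGB batch

with torch.no_grad():
    y = net(x)  # sigmoid probabilities, since n_classes == 1

print(y.shape)  # torch.Size([1, 1, 256, 256])
```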
diff --git a/unet/unet_parts.py b/unet/unet_parts.py
index b24e375..da9c68d 100644
--- a/unet/unet_parts.py
+++ b/unet/unet_parts.py
@@ -1,88 +1,75 @@
-# sub-parts of the U-Net model
+""" Parts of the U-Net model """
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 
-class double_conv(nn.Module):
-    '''(conv => BN => ReLU) * 2'''
-    def __init__(self, in_ch, out_ch):
-        super(double_conv, self).__init__()
-        self.conv = nn.Sequential(
-            nn.Conv2d(in_ch, out_ch, 3, padding=1),
-            nn.BatchNorm2d(out_ch),
+class DoubleConv(nn.Module):
+    """(convolution => [BN] => ReLU) * 2"""
+
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.double_conv = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(out_channels),
             nn.ReLU(inplace=True),
-            nn.Conv2d(out_ch, out_ch, 3, padding=1),
-            nn.BatchNorm2d(out_ch),
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(out_channels),
             nn.ReLU(inplace=True)
         )
 
     def forward(self, x):
-        x = self.conv(x)
-        return x
+        return self.double_conv(x)
 
 
-class inconv(nn.Module):
-    def __init__(self, in_ch, out_ch):
-        super(inconv, self).__init__()
-        self.conv = double_conv(in_ch, out_ch)
+class Down(nn.Module):
+    """Downscaling with maxpool then double conv"""
 
-    def forward(self, x):
-        x = self.conv(x)
-        return x
-
-
-class down(nn.Module):
-    def __init__(self, in_ch, out_ch):
-        super(down, self).__init__()
-        self.mpconv = nn.Sequential(
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.maxpool_conv = nn.Sequential(
             nn.MaxPool2d(2),
-            double_conv(in_ch, out_ch)
+            DoubleConv(in_channels, out_channels)
        )
 
     def forward(self, x):
-        x = self.mpconv(x)
-        return x
+        return self.maxpool_conv(x)
 
 
-class up(nn.Module):
-    def __init__(self, in_ch, out_ch, bilinear=True):
-        super(up, self).__init__()
+class Up(nn.Module):
+    """Upscaling then double conv"""
 
-        # would be a nice idea if the upsampling could be learned too,
-        # but my machine do not have enough memory to handle all those weights
+    def __init__(self, in_channels, out_channels, bilinear=True):
+        super().__init__()
+
+        # if bilinear, use the normal convolutions to reduce the number of channels
         if bilinear:
             self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
         else:
-            self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)
+            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
 
-        self.conv = double_conv(in_ch, out_ch)
+        self.conv = DoubleConv(in_channels, out_channels)
 
     def forward(self, x1, x2):
         x1 = self.up(x1)
 
-        # input is CHW
         diffY = x2.size()[2] - x1.size()[2]
         diffX = x2.size()[3] - x1.size()[3]
 
-        x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
-                        diffY // 2, diffY - diffY//2))
-
-        # for padding issues, see
+        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
+                        diffY // 2, diffY - diffY // 2])
+        # if you have padding issues, see
         # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
         # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
-
         x = torch.cat([x2, x1], dim=1)
-        x = self.conv(x)
-        return x
+        return self.conv(x)
 
 
-class outconv(nn.Module):
-    def __init__(self, in_ch, out_ch):
-        super(outconv, self).__init__()
-        self.conv = nn.Conv2d(in_ch, out_ch, 1)
+class OutConv(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(OutConv, self).__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
 
     def forward(self, x):
-        x = self.conv(x)
-        return x
+        return self.conv(x)
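The `F.pad` call in `Up.forward` is what reconciles odd-sized feature maps before concatenation; a standalone sketch of the size arithmetic with dummy tensors:

```python
import torch
import torch.nn.functional as F

x2 = torch.randn(1, 64, 57, 57)  # skip connection from the encoder
x1 = torch.randn(1, 64, 56, 56)  # up-sampled decoder features (odd size was halved)

diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]

# pad (left, right, top, bottom) so the two maps match spatially
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                diffY // 2, diffY - diffY // 2])
print(x1.shape)                    # torch.Size([1, 64, 57, 57])
x = torch.cat([x2, x1], dim=1)     # now works: (1, 128, 57, 57)
```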
diff --git a/utils/crf.py b/utils/crf.py
index 8a79953..5ee718f 100644
--- a/utils/crf.py
+++ b/utils/crf.py
@@ -1,6 +1,7 @@
 import numpy as np
 import pydensecrf.densecrf as dcrf
 
+
 def dense_crf(img, output_probs):
     h = output_probs.shape[0]
     w = output_probs.shape[1]
diff --git a/utils/data_vis.py b/utils/data_vis.py
index 4ec2f60..95b9130 100644
--- a/utils/data_vis.py
+++ b/utils/data_vis.py
@@ -1,12 +1,17 @@
 import matplotlib.pyplot as plt
 
 
-def plot_img_and_mask(img, mask):
-    fig = plt.figure()
-    a = fig.add_subplot(1, 2, 1)
-    a.set_title('Input image')
-    plt.imshow(img)
-    b = fig.add_subplot(1, 2, 2)
-    b.set_title('Output mask')
-    plt.imshow(mask)
-    plt.show()
\ No newline at end of file
+def plot_img_and_mask(img, mask):
+    classes = mask.shape[2] if len(mask.shape) > 2 else 1
+    fig, ax = plt.subplots(1, classes + 1)
+    ax[0].set_title('Input image')
+    ax[0].imshow(img)
+    if classes > 1:
+        for i in range(classes):
+            ax[i+1].set_title(f'Output mask (class {i+1})')
+            ax[i+1].imshow(mask[:, :, i])
+    else:
+        ax[1].set_title(f'Output mask')
+        ax[1].imshow(mask)
+    plt.xticks([]), plt.yticks([])
+    plt.show()
diff --git a/utils/load.py b/utils/load.py
index e5670e0..306fd54 100644
--- a/utils/load.py
+++ b/utils/load.py
@@ -1,34 +1,27 @@
-#
-# load.py : utils on generators / lists of ids to transform from strings to
-#           cropped images and masks
+""" Utils on generators / lists of ids to transform from strings to cropped images and masks """
 
 import os
 
 import numpy as np
 from PIL import Image
 
-from .utils import resize_and_crop, get_square, normalize, hwc_to_chw
+from .utils import resize_and_crop, normalize, hwc_to_chw
 
 
 def get_ids(dir):
     """Returns a list of the ids in the directory"""
-    return (f[:-4] for f in os.listdir(dir))
-
-
-def split_ids(ids, n=2):
-    """Split each id in n, creating n tuples (id, k) for each id"""
-    return ((id, i) for id in ids for i in range(n))
+    return (os.path.splitext(f)[0] for f in os.listdir(dir) if not f.startswith('.'))
 
 
 def to_cropped_imgs(ids, dir, suffix, scale):
-    """From a list of tuples, returns the correct cropped img"""
-    for id, pos in ids:
+    """From a list of ids, return the cropped images"""
+    for id in ids:
         im = resize_and_crop(Image.open(dir + id + suffix), scale=scale)
-        yield get_square(im, pos)
+        yield im
+
 
 def get_imgs_and_masks(ids, dir_img, dir_mask, scale):
     """Return all the couples (img, mask)"""
-
     imgs = to_cropped_imgs(ids, dir_img, '.jpg', scale)
 
     # need to transform from HWC to CHW
@@ -36,8 +29,9 @@ def get_imgs_and_masks(ids, dir_img, dir_mask, scale):
     imgs_switched = map(hwc_to_chw, imgs)
     imgs_normalized = map(normalize, imgs_switched)
 
     masks = to_cropped_imgs(ids, dir_mask, '_mask.gif', scale)
+    masks_switched = map(hwc_to_chw, masks)
 
-    return zip(imgs_normalized, masks)
+    return zip(imgs_normalized, masks_switched)
 
 
 def get_full_img_and_mask(id, dir_img, dir_mask):
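Taken together, the loading utilities form a lazy generator pipeline; a minimal sketch of how `train.py` consumes them (assuming the `data/imgs`/`data/masks` layout described in the README):

```python
from utils import get_ids, split_train_val, get_imgs_and_masks, batch

ids = get_ids('data/imgs/')                  # generator of filename stems
iddataset = split_train_val(ids, val_percent=0.15)

# lazy (img, mask) pairs; images come out CHW and normalized to [0, 1]
train = get_imgs_and_masks(iddataset['train'], 'data/imgs/', 'data/masks/', scale=0.5)

for b in batch(train, batch_size=2):
    print(len(b))  # lists of up to 2 (img, mask) tuples
    break
```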
diff --git a/utils/utils.py b/utils/utils.py
index 830f1ce..9066edd 100644
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -1,21 +1,12 @@
 import random
+
 import numpy as np
 
-def get_square(img, pos):
-    """Extract a left or a right square from ndarray shape : (H, W, C))"""
-    h = img.shape[0]
-    if pos == 0:
-        return img[:, :h]
-    else:
-        return img[:, -h:]
-
-def split_img_into_squares(img):
-    return get_square(img, 0), get_square(img, 1)
-
 
 def hwc_to_chw(img):
     return np.transpose(img, axes=[2, 0, 1])
 
+
 def resize_and_crop(pilimg, scale=0.5, final_height=None):
     w = pilimg.size[0]
     h = pilimg.size[1]
@@ -29,7 +20,11 @@ def resize_and_crop(pilimg, scale=0.5, final_height=None):
     img = pilimg.resize((newW, newH))
     img = img.crop((0, diff // 2, newW, newH - diff // 2))
-    return np.array(img, dtype=np.float32)
+    ar = np.array(img, dtype=np.float32)
+    if len(ar.shape) == 2:
+        # for greyscale images, add a new axis
+        ar = np.expand_dims(ar, axis=2)
+    return ar
 
 def batch(iterable, batch_size):
     """Yields lists by batch"""
@@ -43,6 +38,7 @@ def batch(iterable, batch_size):
     if len(b) > 0:
         yield b
 
+
 def split_train_val(dataset, val_percent=0.05):
     dataset = list(dataset)
     length = len(dataset)
@@ -54,15 +50,6 @@ def split_train_val(dataset, val_percent=0.05):
 def normalize(x):
     return x / 255
 
-def merge_masks(img1, img2, full_w):
-    h = img1.shape[0]
-
-    new = np.zeros((h, full_w), np.float32)
-    new[:, :full_w // 2 + 1] = img1[:, :full_w // 2 + 1]
-    new[:, full_w // 2 + 1:] = img2[:, -(full_w // 2 - 1):]
-
-    return new
-
 
 # credits to https://stackoverflow.com/users/6076729/manuel-lagunas
 def rle_encode(mask_image):