
U-Net: Semantic segmentation with PyTorch

input and output for a random image in the test dataset

Customized implementation of the U-Net in PyTorch for Kaggle's Carvana Image Masking Challenge from high-definition images.

This model was trained from scratch with 5000 images (no data augmentation) and scored a Dice coefficient of 0.988423 (511 out of 735) on over 100k test images. This score could be improved with more training, data augmentation, fine-tuning, CRF post-processing, and heavier weighting of the mask edges.
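For reference, the Dice coefficient quoted above measures the overlap between a predicted mask A and a ground-truth mask B as 2|A ∩ B| / (|A| + |B|). The following is a minimal sketch in plain Python, written independently of the repository's own `dice_score.py`; the function name `dice_coefficient` and the `eps` smoothing term are illustrative assumptions.

```python
from typing import Sequence


def dice_coefficient(pred: Sequence[int], target: Sequence[int], eps: float = 1e-6) -> float:
    """Dice = 2*|A ∩ B| / (|A| + |B|) for two flattened binary masks (0/1 values)."""
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of pixels")
    # For 0/1 masks, the product is 1 only where both masks are set.
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    # eps keeps the ratio defined when both masks are empty (assumed convention: score 1.0).
    return (2 * intersection + eps) / (total + eps)


pred = [1, 1, 0, 0]
target = [1, 0, 0, 0]
score = dice_coefficient(pred, target)  # intersection 1, mask sizes 2 and 1
```

A score of 1 means the two masks coincide exactly; a score near 0 means they barely overlap.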

Usage

Note: Use Python 3.6 or newer

Docker

A Docker image containing the code and the dependencies is available on DockerHub. You can jump into the container with (Docker >= 19.03):

docker run -it --rm --gpus all milesial/unet

Training

> python train.py -h
usage: train.py [-h] [--epochs E] [--batch-size B] [--learning-rate LR]
                [--load LOAD] [--scale SCALE] [--validation VAL] [--amp]

Train the UNet on images and target masks

optional arguments:
  -h, --help            show this help message and exit
  --epochs E, -e E      Number of epochs
  --batch-size B, -b B  Batch size
  --learning-rate LR, -l LR
                        Learning rate
  --load LOAD, -f LOAD  Load model from a .pth file
  --scale SCALE, -s SCALE
                        Downscaling factor of the images
  --validation VAL, -v VAL
                        Percent of the data that is used as validation (0-100)
  --amp                 Use mixed precision

By default, the scale is 0.5, so if you wish to obtain better results (but use more memory), set it to 1. The input images and target masks should be in the data/imgs and data/masks folders respectively. For Carvana, images are RGB and masks are black and white.
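To make the effect of `--scale` and `--validation` concrete, the arithmetic can be sketched as below. This assumes that each image side is multiplied by the scale factor and truncated, and that the validation count is the integer part of the given percentage. The helper names are hypothetical, and `train.py` may round differently.

```python
from typing import Tuple


def scaled_size(width: int, height: int, scale: float) -> Tuple[int, int]:
    """Image size after applying the scale factor to both sides (assumed truncation)."""
    if not 0 < scale <= 1:
        raise ValueError("scale must be in (0, 1]")
    return int(width * scale), int(height * scale)


def split_counts(n_images: int, val_percent: float) -> Tuple[int, int]:
    """Number of (training, validation) images for a validation percentage in 0-100."""
    n_val = int(n_images * val_percent / 100)
    return n_images - n_val, n_val


# A 1918x1280 image at scale 0.5 keeps a quarter of its pixels.
w, h = scaled_size(1918, 1280, 0.5)
n_train, n_val = split_counts(5000, 10.0)
```

Because the pixel count shrinks with the square of the scale factor, halving the scale cuts the per-image activation memory to roughly a quarter.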

Prediction

After training your model and saving it to MODEL.pth, you can easily test the output masks on your images via the CLI.

To predict a single image and save it:

python predict.py -i image.jpg -o output.jpg

To predict multiple images and show them without saving them:

python predict.py -i image1.jpg image2.jpg --viz --no-save

> python predict.py -h
usage: predict.py [-h] [--model FILE] --input INPUT [INPUT ...] 
                  [--output INPUT [INPUT ...]] [--viz] [--no-save]
                  [--mask-threshold MASK_THRESHOLD] [--scale SCALE]

Predict masks from input images

optional arguments:
  -h, --help            show this help message and exit
  --model FILE, -m FILE
                        Specify the file in which the model is stored
  --input INPUT [INPUT ...], -i INPUT [INPUT ...]
                        Filenames of input images
  --output INPUT [INPUT ...], -o INPUT [INPUT ...]
                        Filenames of output images
  --viz, -v             Visualize the images as they are processed
  --no-save, -n         Do not save the output masks
  --mask-threshold MASK_THRESHOLD, -t MASK_THRESHOLD
                        Minimum probability value to consider a mask pixel white
  --scale SCALE, -s SCALE
                        Scale factor for the input images

You can specify which model file to use with --model MODEL.pth.
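The `--mask-threshold` option can be read as a per-pixel cutoff on the predicted probability. Here is a minimal sketch of that step on a small grid of made-up probabilities; the function name is hypothetical, and whether `predict.py` uses a strict or non-strict comparison is an assumption.

```python
from typing import List


def threshold_mask(probs: List[List[float]], threshold: float) -> List[List[int]]:
    """Turn per-pixel probabilities into a black/white mask (255 = white, 0 = black)."""
    return [[255 if p > threshold else 0 for p in row] for row in probs]


# Made-up probabilities for a 2x2 image.
probs = [[0.1, 0.7],
         [0.5, 0.9]]
mask = threshold_mask(probs, threshold=0.5)
```

Raising the threshold makes the mask more conservative (fewer white pixels), while lowering it makes the mask more inclusive.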

Weights & Biases

The training progress can be visualized in real-time using Weights & Biases. Loss curves, validation curves, weights and gradient histograms, as well as predicted masks are logged to the platform.

When you launch a training run, a link is printed in the console. Click it to open your dashboard. If you have an existing W&B account, you can link it by setting the WANDB_API_KEY environment variable.
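If you prefer to start training from a Python script rather than a shell, one possible approach is to build the command line and the environment explicitly. The sketch below uses only flags listed in the `train.py` help output; the helper names and the placeholder key are hypothetical.

```python
import os
import sys
from typing import Dict, List


def training_command(epochs: int, batch_size: int, scale: float, amp: bool = False) -> List[str]:
    """Build a train.py command line from flags listed in its help output."""
    cmd = [sys.executable, "train.py",
           "--epochs", str(epochs),
           "--batch-size", str(batch_size),
           "--scale", str(scale)]
    if amp:
        cmd.append("--amp")
    return cmd


def env_with_wandb_key(api_key: str) -> Dict[str, str]:
    """Copy of the current environment with WANDB_API_KEY set."""
    env = dict(os.environ)
    env["WANDB_API_KEY"] = api_key
    return env


cmd = training_command(epochs=5, batch_size=1, scale=0.5, amp=True)
env = env_with_wandb_key("YOUR_API_KEY")  # placeholder, not a real key
# To actually launch training from the repository root:
#   import subprocess; subprocess.run(cmd, env=env, check=True)
```

Passing a copied environment keeps the key out of the parent process's own environment, which can be convenient when several runs use different accounts.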

Pretrained model

A pretrained model is available for the Carvana dataset. It can also be loaded from torch.hub:

net = torch.hub.load('milesial/Pytorch-UNet', 'unet_carvana')

Training was done at 100% scale with bilinear upsampling.

Data

The Carvana data is available on the Kaggle website.

You can also download it using your Kaggle API key with:

bash download_data.sh <username> <apikey>

Notes on memory

The model has been trained from scratch on a GTX970M 3GB. Predicting images of 1918x1280 takes 1.5GB of memory. Training takes approximately 3GB, so if you are a few MB shy of memory, consider turning off all graphical displays. This assumes you use bilinear up-sampling, and not transposed convolution in the model.
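A rough back-of-the-envelope estimate shows why full-resolution images are memory-hungry. The sketch below assumes 32-bit floats and 64 feature channels in the first convolution block, as in the original U-Net paper. It counts a single activation tensor only, so real usage is higher.

```python
def tensor_bytes(channels: int, height: int, width: int, bytes_per_value: int = 4) -> int:
    """Size in bytes of one dense (channels, height, width) tensor."""
    return channels * height * width * bytes_per_value


MIB = 1024 ** 2

# RGB input image of 1918x1280 pixels stored as float32.
input_mib = tensor_bytes(3, 1280, 1918) / MIB
# One 64-channel feature map at full resolution (assumed first-block width).
feature_mib = tensor_bytes(64, 1280, 1918) / MIB

print(f"input: {input_mib:.1f} MiB, first feature map: {feature_mib:.1f} MiB")
```

A single full-resolution feature map is already a sizeable fraction of a 3GB card, which is consistent with the advice to lower `--scale` when memory is tight.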

Convergence

See a reference training run with the Carvana dataset on TensorBoard.dev (only scalars are shown currently).


Original paper by Olaf Ronneberger, Philipp Fischer, Thomas Brox: https://arxiv.org/abs/1505.04597

network architecture