# UNet: semantic segmentation with PyTorch
[![xscode](https://img.shields.io/badge/Available%20on-xs%3Acode-blue?style=?style=plastic&logo=appveyor&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAEAAAABACAMAAACdt4HsAAAAGXRFWHRTb2Z0d2FyZQBBZG9iZSBJbWFnZVJlYWR5ccllPAAAAAZQTFRF////////VXz1bAAAAAJ0Uk5T/wDltzBKAAAAlUlEQVR42uzXSwqAMAwE0Mn9L+3Ggtgkk35QwcnSJo9S+yGwM9DCooCbgn4YrJ4CIPUcQF7/XSBbx2TEz4sAZ2q1RAECBAiYBlCtvwN+KiYAlG7UDGj59MViT9hOwEqAhYCtAsUZvL6I6W8c2wcbd+LIWSCHSTeSAAECngN4xxIDSK9f4B9t377Wd7H5Nt7/Xz8eAgwAvesLRjYYPuUAAAAASUVORK5CYII=)](https://xscode.com/milesial/Pytorch-UNet)

![input and output for a random image in the test dataset](https://framapic.org/OcE8HlU6me61/KNTt8GFQzxDR.png)

Customized implementation of the [U-Net](https://arxiv.org/abs/1505.04597) in PyTorch for Kaggle's [Carvana Image Masking Challenge](https://www.kaggle.com/c/carvana-image-masking-challenge) from high definition images.
This model was trained from scratch with 5000 images (no data augmentation) and scored a [Dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) of 0.988423 (511 out of 735) on over 100k test images. This score could be improved with more training, data augmentation, fine-tuning, CRF post-processing, and applying more weight to the edges of the masks.

The Carvana data is available on the [Kaggle website](https://www.kaggle.com/c/carvana-image-masking-challenge/data).
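As an illustration of the scoring metric (a sketch, not the repository's exact implementation): for binary masks, the Dice coefficient is twice the intersection divided by the sum of the mask sizes.

```python
def dice_coefficient(pred, target, eps=1e-6):
    """Dice coefficient between two flat binary masks (lists of 0/1).

    dice = 2 * |A & B| / (|A| + |B|); eps avoids division by zero
    when both masks are empty.
    """
    intersection = sum(p * t for p, t in zip(pred, target))
    return (2.0 * intersection + eps) / (sum(pred) + sum(target) + eps)

pred   = [1, 1, 0, 0, 1]
target = [1, 0, 0, 0, 1]
print(round(dice_coefficient(pred, target), 3))  # 2*2/(3+2) = 0.8
```

A score of 1.0 means the predicted mask matches the ground truth exactly.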
## Usage
**Note: Use Python 3.6 or newer.**
### Prediction
After training your model and saving it to `MODEL.pth`, you can test the output masks on your images via the CLI.
To predict a single image and save it:

`python predict.py -i image.jpg -o output.jpg`

To predict multiple images and show them without saving them:

`python predict.py -i image1.jpg image2.jpg --viz --no-save`
```shell script
> python predict.py -h
usage: predict.py [-h] [--model FILE] --input INPUT [INPUT ...]
                  [--output INPUT [INPUT ...]] [--viz] [--no-save]
                  [--mask-threshold MASK_THRESHOLD] [--scale SCALE]

Predict masks from input images

optional arguments:
  -h, --help            show this help message and exit
  --model FILE, -m FILE
                        Specify the file in which the model is stored
                        (default: MODEL.pth)
  --input INPUT [INPUT ...], -i INPUT [INPUT ...]
                        filenames of input images (default: None)
  --output INPUT [INPUT ...], -o INPUT [INPUT ...]
                        Filenames of output images (default: None)
  --viz, -v             Visualize the images as they are processed (default:
                        False)
  --no-save, -n         Do not save the output masks (default: False)
  --mask-threshold MASK_THRESHOLD, -t MASK_THRESHOLD
                        Minimum probability value to consider a mask pixel
                        white (default: 0.5)
  --scale SCALE, -s SCALE
                        Scale factor for the input images (default: 0.5)
```
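As an illustration of what `--mask-threshold` controls (a minimal sketch, not the repository's exact code): the network outputs a per-pixel probability, and a pixel is kept in the mask only if its probability exceeds the threshold.

```python
def probs_to_mask(probs, threshold=0.5):
    """Binarize a 2-D grid of per-pixel probabilities.

    A pixel is white (1) only if its probability is strictly above
    the threshold, mirroring the --mask-threshold default of 0.5.
    """
    return [[1 if p > threshold else 0 for p in row] for row in probs]

probs = [[0.1, 0.7],
         [0.5, 0.9]]
print(probs_to_mask(probs))  # [[0, 1], [0, 1]] -- 0.5 is not strictly above 0.5
```

Raising the threshold makes the mask stricter (fewer white pixels); lowering it makes it more permissive.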
You can specify which model file to use with `--model MODEL.pth`.
### Training
```shell script
> python train.py -h
usage: train.py [-h] [-e E] [-b [B]] [-l [LR]] [-f LOAD] [-s SCALE] [-v VAL]

Train the UNet on images and target masks

optional arguments:
  -h, --help            show this help message and exit
  -e E, --epochs E      Number of epochs (default: 5)
  -b [B], --batch-size [B]
                        Batch size (default: 1)
  -l [LR], --learning-rate [LR]
                        Learning rate (default: 0.1)
  -f LOAD, --load LOAD  Load model from a .pth file (default: False)
  -s SCALE, --scale SCALE
                        Downscaling factor of the images (default: 0.5)
  -v VAL, --validation VAL
                        Percent of the data that is used as validation (0-100)
                        (default: 15.0)
```
By default, the `scale` is 0.5, so if you wish to obtain better results (but use more memory), set it to 1.
The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively.
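For a sense of what the scale factor does (a sketch, assuming scaling is a plain resize of both dimensions): the Carvana images are 1918×1280, so the default scale of 0.5 feeds roughly quarter-area images to the network.

```python
def scaled_size(width, height, scale=0.5):
    """Dimensions of the image fed to the network after downscaling."""
    return int(width * scale), int(height * scale)

# Carvana images are 1918x1280; with the default scale of 0.5:
print(scaled_size(1918, 1280))  # (959, 640)
```

Halving each dimension quarters the pixel count, which is where the memory savings come from.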
### Pretrained model
A [pretrained model](https://github.com/milesial/Pytorch-UNet/releases/tag/v1.0) is available for the Carvana dataset. It can also be loaded from `torch.hub`:
```python
net = torch.hub.load('milesial/Pytorch-UNet', 'unet_carvana')
```
The training was done at 100% scale with bilinear upsampling.
## Tensorboard
You can visualize the training and test losses, the weights and gradients, and the model predictions in real time with TensorBoard:
`tensorboard --logdir=runs`

You can find a reference training run with the Carvana dataset on [TensorBoard.dev](https://tensorboard.dev/experiment/1m1Ql50MSJixCbG1m9EcDQ/#scalars&_smoothingWeight=0.6) (only scalars are shown currently).
## Notes on memory
The model has been trained from scratch on a 3GB GTX 970M.
Predicting images of 1918×1280 takes 1.5GB of memory.

Training takes approximately 3GB, so if you are a few MB shy of memory, consider turning off all graphical displays.
This assumes you use bilinear upsampling in the model, not transposed convolutions.
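The memory difference comes from bilinear upsampling being parameter-free, whereas transposed convolutions carry learnable weights. A 1-D analogue of parameter-free linear upsampling (an illustrative sketch, not the model's code):

```python
def upsample_1d_linear(xs, factor=2):
    """Parameter-free linear upsampling of a 1-D signal, the 1-D
    analogue of the bilinear upsampling used in the decoder.

    Each output sample between two inputs is a linear interpolation;
    no weights are learned or stored, unlike a transposed convolution.
    """
    out = []
    for i in range(len(xs) - 1):
        for k in range(factor):
            t = k / factor
            out.append(xs[i] * (1 - t) + xs[i + 1] * t)
    out.append(xs[-1])
    return out

print(upsample_1d_linear([0.0, 1.0, 2.0]))  # [0.0, 0.5, 1.0, 1.5, 2.0]
```

A transposed convolution would instead multiply by a learned kernel, adding parameters (and gradients during training) for every up-convolution in the decoder.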
## Support
Personalized support for issues with this repository, or for integrating it with your own dataset, is available on [xs:code](https://xscode.com/milesial/Pytorch-UNet).
---
Original paper by Olaf Ronneberger, Philipp Fischer, Thomas Brox: [https://arxiv.org/abs/1505.04597 ](https://arxiv.org/abs/1505.04597 )
![network architecture ](https://i.imgur.com/jeDVpqF.png )