# UNet: semantic segmentation with PyTorch

![input and output for a random image in the test dataset](https://framapic.org/OcE8HlU6me61/KNTt8GFQzxDR.png)

Customized implementation of the [U-Net](https://arxiv.org/abs/1505.04597) in PyTorch for Kaggle's [Carvana Image Masking Challenge](https://www.kaggle.com/c/carvana-image-masking-challenge), using high-definition images.

This model was trained from scratch with 5000 images (no data augmentation) and scored a [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) of 0.988423 (ranked 511 out of 735) on over 100k test images. This score could be improved with more training, data augmentation, fine-tuning, experimenting with CRF post-processing, and applying more weight to the edges of the masks.

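The Dice coefficient compares a predicted mask X with the ground-truth mask Y as 2|X ∩ Y| / (|X| + |Y|), so 1.0 means perfect overlap and 0.0 means none. Below is a minimal pure-Python sketch for two binary masks flattened into sequences of 0s and 1s; the helper name `dice_coefficient` and the choice to return 1.0 when both masks are empty are illustrative assumptions, not necessarily how this repository or Kaggle's scorer compute it.

```python
from typing import Sequence


def dice_coefficient(pred: Sequence[int], target: Sequence[int]) -> float:
    """Dice = 2 * |X ∩ Y| / (|X| + |Y|) for two flattened binary masks.

    Both inputs are assumed to contain only 0s and 1s and to have equal length.
    """
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of pixels")
    # A pixel is in the intersection only when both masks mark it with 1.
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    if total == 0:
        # Assumed convention: two empty masks count as a perfect match.
        return 1.0
    return 2.0 * intersection / total


# Toy 2x3 masks, flattened row by row.
pred = [1, 1, 0,
        0, 1, 0]
target = [1, 0, 0,
          0, 1, 1]
print(dice_coefficient(pred, target))  # 2 * 2 / (3 + 3), i.e. about 0.667
```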
The Carvana data is available on the [Kaggle website](https://www.kaggle.com/c/carvana-image-masking-challenge/data).

## Usage

**Note: use Python 3.**

### Prediction

You can easily test the output masks on your images via the CLI.

To predict a single image and save it:

`python predict.py -i image.jpg -o output.jpg`

To predict multiple images and show them without saving them:

`python predict.py -i image1.jpg image2.jpg --viz --no-save`

```shell script
> python predict.py -h
usage: predict.py [-h] [--model FILE] --input INPUT [INPUT ...]
                  [--output INPUT [INPUT ...]] [--viz] [--no-save]
                  [--mask-threshold MASK_THRESHOLD] [--scale SCALE]

Predict masks from input images

optional arguments:
  -h, --help            show this help message and exit
  --model FILE, -m FILE
                        Specify the file in which the model is stored
                        (default: MODEL.pth)
  --input INPUT [INPUT ...], -i INPUT [INPUT ...]
                        filenames of input images (default: None)
  --output INPUT [INPUT ...], -o INPUT [INPUT ...]
                        Filenames of ouput images (default: None)
  --viz, -v             Visualize the images as they are processed (default:
                        False)
  --no-save, -n         Do not save the output masks (default: False)
  --mask-threshold MASK_THRESHOLD, -t MASK_THRESHOLD
                        Minimum probability value to consider a mask pixel
                        white (default: 0.5)
  --scale SCALE, -s SCALE
                        Scale factor for the input images (default: 0.5)
```

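The `--mask-threshold` option decides which pixels end up white in the saved mask. The sketch below shows that step under two assumptions this README does not state: the network emits one raw score (logit) per pixel that is mapped to a probability with a sigmoid, and white is encoded as 255. `logits_to_mask` is a hypothetical helper, and it uses a strict `>` comparison against the threshold.

```python
import math
from typing import List


def sigmoid(x: float) -> float:
    """Numerically stable logistic function (avoids overflow for large |x|)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def logits_to_mask(logits: List[List[float]], threshold: float = 0.5) -> List[List[int]]:
    """Turn a 2-D grid of raw scores into a black (0) / white (255) mask.

    A pixel becomes white when its sigmoid probability exceeds `threshold`.
    """
    return [[255 if sigmoid(v) > threshold else 0 for v in row] for row in logits]


scores = [[-2.0, 0.5],
          [3.0, -0.1]]
print(logits_to_mask(scores))                 # [[0, 255], [255, 0]]
print(logits_to_mask(scores, threshold=0.9))  # [[0, 0], [255, 0]]
```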
You can specify which model file to use with `--model MODEL.pth`.

### Training

```shell script
> python train.py -h
usage: train.py [-h] [-e E] [-b [B]] [-l [LR]] [-f LOAD] [-s SCALE] [-v VAL]

Train the UNet on images and target masks

optional arguments:
  -h, --help            show this help message and exit
  -e E, --epochs E      Number of epochs (default: 5)
  -b [B], --batch-size [B]
                        Batch size (default: 1)
  -l [LR], --learning-rate [LR]
                        Learning rate (default: 0.1)
  -f LOAD, --load LOAD  Load model from a .pth file (default: False)
  -s SCALE, --scale SCALE
                        Downscaling factor of the images (default: 0.5)
  -v VAL, --validation VAL
                        Percent of the data that is used as validation (0-100)
                        (default: 15.0)
```

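The `-v/--validation` option holds out a percentage of the images for validation. Here is a sketch of such a split, assuming a seeded random shuffle and rounding the validation count down; `split_train_val` is a hypothetical helper, and `train.py` may split the data differently.

```python
import random
from typing import List, Sequence, Tuple


def split_train_val(ids: Sequence[str], val_percent: float = 15.0,
                    seed: int = 0) -> Tuple[List[str], List[str]]:
    """Shuffle `ids` reproducibly and hold out `val_percent` percent for validation."""
    if not 0 <= val_percent <= 100:
        raise ValueError("val_percent must be between 0 and 100")
    shuffled = list(ids)
    random.Random(seed).shuffle(shuffled)  # local RNG, so global state is untouched
    n_val = int(len(shuffled) * val_percent / 100)  # round down
    return shuffled[n_val:], shuffled[:n_val]


ids = [f"car_{i:03d}" for i in range(100)]
train_ids, val_ids = split_train_val(ids)
print(len(train_ids), len(val_ids))  # 85 15
```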
By default, `scale` is 0.5; to obtain better results at the cost of more memory, set it to 1.

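To see what the scale factor does to the network's input, the sketch below computes the rescaled size of a Carvana image (1918×1280), assuming each side is multiplied by the factor and truncated to an integer; that rounding rule and the helper `scaled_size` are assumptions, not taken from this repository. Because the pixel count shrinks with the square of the factor, the default of 0.5 feeds the network roughly a quarter of the pixels that `--scale 1` does, which is why a larger scale needs more memory.

```python
from typing import Tuple


def scaled_size(width: int, height: int, scale: float) -> Tuple[int, int]:
    """Return the (width, height) an image would have after rescaling by `scale`."""
    if not 0 < scale <= 1:
        raise ValueError("scale is expected to be a downscaling factor in (0, 1]")
    new_w, new_h = int(width * scale), int(height * scale)
    if new_w == 0 or new_h == 0:
        raise ValueError("scale is too small for this image")
    return new_w, new_h


# Compare the pixel count at a few scales for a 1918x1280 Carvana image.
for s in (1.0, 0.5, 0.25):
    w, h = scaled_size(1918, 1280, s)
    print(s, (w, h), f"{w * h / (1918 * 1280):.2%} of the pixels")
```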
The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively.

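Before training, it can help to check that every image in `data/imgs` has a matching mask in `data/masks`. The sketch below pairs files by name, assuming Carvana's naming scheme where the mask of `<id>.jpg` is `<id>_mask.gif`; the matching rule, the `mask_suffix` parameter, and the helper `pair_images_with_masks` are all hypothetical and not necessarily what `train.py` does.

```python
from pathlib import Path
from typing import Dict, List, Tuple


def pair_images_with_masks(img_dir: str = "data/imgs", mask_dir: str = "data/masks",
                           mask_suffix: str = "_mask") -> Tuple[List[Tuple[Path, Path]], List[Path]]:
    """Match each image to the mask whose stem is `<image stem><mask_suffix>`.

    Returns the (image, mask) pairs and the list of images with no mask.
    """
    # Index masks by file stem (name without the last extension).
    masks: Dict[str, Path] = {p.stem: p for p in Path(mask_dir).iterdir() if p.is_file()}
    pairs: List[Tuple[Path, Path]] = []
    missing: List[Path] = []
    for img in sorted(Path(img_dir).iterdir()):
        if not img.is_file():
            continue
        mask = masks.get(img.stem + mask_suffix)
        if mask is None:
            missing.append(img)
        else:
            pairs.append((img, mask))
    return pairs, missing
```

Running `pair_images_with_masks()` from the repository root and checking that the returned `missing` list is empty is a quick sanity check of the folder layout.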
## Dependencies

This package depends on [pydensecrf](https://github.com/lucasb-eyer/pydensecrf), available via `pip install`.

## Notes on memory

The model has been trained from scratch on a GTX 970M with 3 GB of memory.
Predicting images of 1918×1280 pixels takes 1.5 GB of memory.
Training takes approximately 3 GB, so if you are a few MB short of memory, consider turning off all graphical displays.
These figures assume the model uses bilinear up-sampling, not transposed convolutions.

---

Original paper by Olaf Ronneberger, Philipp Fischer, Thomas Brox: [https://arxiv.org/abs/1505.04597](https://arxiv.org/abs/1505.04597)

![network architecture](https://i.imgur.com/jeDVpqF.png)