README update
Former-commit-id: c551c91786f26acc93abed0c043115ab8ef2fce0
This commit is contained in:
parent
c5257f8367
commit
a7270ca699
106
README.md
106
README.md
|
@ -1,6 +1,4 @@
|
||||||
# UNet: semantic segmentation with PyTorch
|
# U-Net: Semantic segmentation with PyTorch
|
||||||
|
|
||||||
[![xscode](https://img.shields.io/badge/Available%20on-xs%3Acode-blue?style=?style=plastic&logo=appveyor&logo=)](https://xscode.com/milesial/Pytorch-UNet)
|
|
||||||
|
|
||||||
|
|
||||||
![input and output for a random image in the test dataset](https://i.imgur.com/GD8FcB7.png)
|
![input and output for a random image in the test dataset](https://i.imgur.com/GD8FcB7.png)
|
||||||
|
@ -10,13 +8,51 @@ Customized implementation of the [U-Net](https://arxiv.org/abs/1505.04597) in Py
|
||||||
|
|
||||||
This model was trained from scratch with 5000 images (no data augmentation) and scored a [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) of 0.988423 (511 out of 735) on over 100k test images. This score could be improved with more training, data augmentation, fine tuning, playing with CRF post-processing, and applying more weights on the edges of the masks.
|
This model was trained from scratch with 5000 images (no data augmentation) and scored a [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) of 0.988423 (511 out of 735) on over 100k test images. This score could be improved with more training, data augmentation, fine tuning, playing with CRF post-processing, and applying more weights on the edges of the masks.
|
||||||
|
|
||||||
The Carvana data is available on the [Kaggle website](https://www.kaggle.com/c/carvana-image-masking-challenge/data).
|
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
**Note : Use Python 3.6 or newer**
|
**Note : Use Python 3.6 or newer**
|
||||||
|
|
||||||
|
### Docker
|
||||||
|
|
||||||
|
A docker image containing the code and the dependencies is available on [DockerHub](https://hub.docker.com/repository/docker/milesial/unet).
|
||||||
|
You can jump in the container with ([docker >=19.03](https://docs.docker.com/get-docker/)):
|
||||||
|
|
||||||
|
```shell script
|
||||||
|
docker run -it --rm --gpus all milesial/unet
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### Training
|
||||||
|
|
||||||
|
```shell script
|
||||||
|
> python train.py -h
|
||||||
|
usage: train.py [-h] [--epochs E] [--batch-size B] [--learning-rate LR]
|
||||||
|
[--load LOAD] [--scale SCALE] [--validation VAL] [--amp]
|
||||||
|
|
||||||
|
Train the UNet on images and target masks
|
||||||
|
|
||||||
|
optional arguments:
|
||||||
|
-h, --help show this help message and exit
|
||||||
|
--epochs E, -e E Number of epochs
|
||||||
|
--batch-size B, -b B Batch size
|
||||||
|
--learning-rate LR, -l LR
|
||||||
|
Learning rate
|
||||||
|
--load LOAD, -f LOAD Load model from a .pth file
|
||||||
|
--scale SCALE, -s SCALE
|
||||||
|
Downscaling factor of the images
|
||||||
|
--validation VAL, -v VAL
|
||||||
|
Percent of the data that is used as validation (0-100)
|
||||||
|
--amp Use mixed precision
|
||||||
|
```
|
||||||
|
|
||||||
|
By default, the `scale` is 0.5, so if you wish to obtain better results (but use more memory), set it to 1.
|
||||||
|
The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively. For Carvana, images are RGB and masks are black and white.
|
||||||
|
|
||||||
### Prediction
|
### Prediction
|
||||||
|
|
||||||
After training your model and saving it to MODEL.pth, you can easily test the output masks on your images via the CLI.
|
After training your model and saving it to `MODEL.pth`, you can easily test the output masks on your images via the CLI.
|
||||||
|
|
||||||
To predict a single image and save it:
|
To predict a single image and save it:
|
||||||
|
|
||||||
|
@ -28,7 +64,7 @@ To predict a multiple images and show them without saving them:
|
||||||
|
|
||||||
```shell script
|
```shell script
|
||||||
> python predict.py -h
|
> python predict.py -h
|
||||||
usage: predict.py [-h] [--model FILE] --input INPUT [INPUT ...]
|
usage: predict.py [-h] [--model FILE] --input INPUT [INPUT ...]
|
||||||
[--output INPUT [INPUT ...]] [--viz] [--no-save]
|
[--output INPUT [INPUT ...]] [--viz] [--no-save]
|
||||||
[--mask-threshold MASK_THRESHOLD] [--scale SCALE]
|
[--mask-threshold MASK_THRESHOLD] [--scale SCALE]
|
||||||
|
|
||||||
|
@ -38,48 +74,26 @@ optional arguments:
|
||||||
-h, --help show this help message and exit
|
-h, --help show this help message and exit
|
||||||
--model FILE, -m FILE
|
--model FILE, -m FILE
|
||||||
Specify the file in which the model is stored
|
Specify the file in which the model is stored
|
||||||
(default: MODEL.pth)
|
|
||||||
--input INPUT [INPUT ...], -i INPUT [INPUT ...]
|
--input INPUT [INPUT ...], -i INPUT [INPUT ...]
|
||||||
filenames of input images (default: None)
|
Filenames of input images
|
||||||
--output INPUT [INPUT ...], -o INPUT [INPUT ...]
|
--output INPUT [INPUT ...], -o INPUT [INPUT ...]
|
||||||
Filenames of ouput images (default: None)
|
Filenames of output images
|
||||||
--viz, -v Visualize the images as they are processed (default:
|
--viz, -v Visualize the images as they are processed
|
||||||
False)
|
--no-save, -n Do not save the output masks
|
||||||
--no-save, -n Do not save the output masks (default: False)
|
|
||||||
--mask-threshold MASK_THRESHOLD, -t MASK_THRESHOLD
|
--mask-threshold MASK_THRESHOLD, -t MASK_THRESHOLD
|
||||||
Minimum probability value to consider a mask pixel
|
Minimum probability value to consider a mask pixel white
|
||||||
white (default: 0.5)
|
|
||||||
--scale SCALE, -s SCALE
|
--scale SCALE, -s SCALE
|
||||||
Scale factor for the input images (default: 0.5)
|
Scale factor for the input images
|
||||||
```
|
```
|
||||||
You can specify which model file to use with `--model MODEL.pth`.
|
You can specify which model file to use with `--model MODEL.pth`.
|
||||||
|
|
||||||
### Training
|
### Weights & Biases
|
||||||
|
|
||||||
```shell script
|
The training progress can be visualized in real-time using [Weights & Biases](wandb.ai/). Loss curves, validation curves, weights and gradient histograms, as well as predicted masks are logged to the platform.
|
||||||
> python train.py -h
|
|
||||||
usage: train.py [-h] [-e E] [-b [B]] [-l [LR]] [-f LOAD] [-s SCALE] [-v VAL]
|
|
||||||
|
|
||||||
Train the UNet on images and target masks
|
When launching a training, a link will be printed in the console. Click on it to go to your dashboard. If you have an existing W&B account, you can link it
|
||||||
|
by setting the `WANDB_API_KEY` environment variable.
|
||||||
|
|
||||||
optional arguments:
|
|
||||||
-h, --help show this help message and exit
|
|
||||||
-e E, --epochs E Number of epochs (default: 5)
|
|
||||||
-b [B], --batch-size [B]
|
|
||||||
Batch size (default: 1)
|
|
||||||
-l [LR], --learning-rate [LR]
|
|
||||||
Learning rate (default: 0.1)
|
|
||||||
-f LOAD, --load LOAD Load model from a .pth file (default: False)
|
|
||||||
-s SCALE, --scale SCALE
|
|
||||||
Downscaling factor of the images (default: 0.5)
|
|
||||||
-v VAL, --validation VAL
|
|
||||||
Percent of the data that is used as validation (0-100)
|
|
||||||
(default: 15.0)
|
|
||||||
|
|
||||||
```
|
|
||||||
By default, the `scale` is 0.5, so if you wish to obtain better results (but use more memory), set it to 1.
|
|
||||||
|
|
||||||
The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively.
|
|
||||||
|
|
||||||
### Pretrained model
|
### Pretrained model
|
||||||
A [pretrained model](https://github.com/milesial/Pytorch-UNet/releases/tag/v1.0) is available for the Carvana dataset. It can also be loaded from torch.hub:
|
A [pretrained model](https://github.com/milesial/Pytorch-UNet/releases/tag/v1.0) is available for the Carvana dataset. It can also be loaded from torch.hub:
|
||||||
|
@ -89,12 +103,14 @@ net = torch.hub.load('milesial/Pytorch-UNet', 'unet_carvana')
|
||||||
```
|
```
|
||||||
The training was done with a 100% scale and bilinear upsampling.
|
The training was done with a 100% scale and bilinear upsampling.
|
||||||
|
|
||||||
## Tensorboard
|
## Data
|
||||||
You can visualize in real time the train and test losses, the weights and gradients, along with the model predictions with tensorboard:
|
The Carvana data is available on the [Kaggle website](https://www.kaggle.com/c/carvana-image-masking-challenge/data).
|
||||||
|
|
||||||
`tensorboard --logdir=runs`
|
You can also download it using your Kaggle API key with:
|
||||||
|
|
||||||
You can find a reference training run with the Caravana dataset on [TensorBoard.dev](https://tensorboard.dev/experiment/1m1Ql50MSJixCbG1m9EcDQ/#scalars&_smoothingWeight=0.6) (only scalars are shown currently).
|
```shell script
|
||||||
|
bash download_data.sh <username> <apikey>
|
||||||
|
```
|
||||||
|
|
||||||
## Notes on memory
|
## Notes on memory
|
||||||
|
|
||||||
|
@ -103,9 +119,11 @@ Predicting images of 1918*1280 takes 1.5GB of memory.
|
||||||
Training takes much approximately 3GB, so if you are a few MB shy of memory, consider turning off all graphical displays.
|
Training takes much approximately 3GB, so if you are a few MB shy of memory, consider turning off all graphical displays.
|
||||||
This assumes you use bilinear up-sampling, and not transposed convolution in the model.
|
This assumes you use bilinear up-sampling, and not transposed convolution in the model.
|
||||||
|
|
||||||
## Support
|
## Convergence
|
||||||
|
|
||||||
|
See a reference training run with the Caravana dataset on [TensorBoard.dev](https://tensorboard.dev/experiment/1m1Ql50MSJixCbG1m9EcDQ/#scalars&_smoothingWeight=0.6) (only scalars are shown currently).
|
||||||
|
|
||||||
|
|
||||||
Personalized support for issues with this repository, or integrating with your own dataset, available on [xs:code](https://xscode.com/milesial/Pytorch-UNet).
|
|
||||||
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
|
@ -74,7 +74,7 @@ def mask_to_image(mask: np.ndarray):
|
||||||
if mask.ndim == 2:
|
if mask.ndim == 2:
|
||||||
return Image.fromarray((mask * 255).astype(np.uint8))
|
return Image.fromarray((mask * 255).astype(np.uint8))
|
||||||
elif mask.ndim == 3:
|
elif mask.ndim == 3:
|
||||||
return Image.fromarray((np.argmax(mask, dim=0) * 255 / mask.shape[0]).astype(np.uint8))
|
return Image.fromarray((np.argmax(mask, axis=0) * 255 / mask.shape[0]).astype(np.uint8))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
Loading…
Reference in a new issue