From a7270ca699d9981cea0a01a8e986a099c75fab7d Mon Sep 17 00:00:00 2001 From: milesial Date: Mon, 16 Aug 2021 06:01:29 +0200 Subject: [PATCH] README update Former-commit-id: c551c91786f26acc93abed0c043115ab8ef2fce0 --- README.md | 106 +++++++++++++++++++++++++++++++---------------------- predict.py | 2 +- 2 files changed, 63 insertions(+), 45 deletions(-) diff --git a/README.md b/README.md index 602bf65..99549cc 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,4 @@ -# UNet: semantic segmentation with PyTorch - -[![xscode](https://img.shields.io/badge/Available%20on-xs%3Acode-blue?style=?style=plastic&logo=appveyor&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAEAAAABACAMAAACdt4HsAAAAGXRFWHRTb2Z0d2FyZQBBZG9iZSBJbWFnZVJlYWR5ccllPAAAAAZQTFRF////////VXz1bAAAAAJ0Uk5T/wDltzBKAAAAlUlEQVR42uzXSwqAMAwE0Mn9L+3Ggtgkk35QwcnSJo9S+yGwM9DCooCbgn4YrJ4CIPUcQF7/XSBbx2TEz4sAZ2q1RAECBAiYBlCtvwN+KiYAlG7UDGj59MViT9hOwEqAhYCtAsUZvL6I6W8c2wcbd+LIWSCHSTeSAAECngN4xxIDSK9f4B9t377Wd7H5Nt7/Xz8eAgwAvesLRjYYPuUAAAAASUVORK5CYII=)](https://xscode.com/milesial/Pytorch-UNet) +# U-Net: Semantic segmentation with PyTorch ![input and output for a random image in the test dataset](https://i.imgur.com/GD8FcB7.png) @@ -10,13 +8,51 @@ Customized implementation of the [U-Net](https://arxiv.org/abs/1505.04597) in Py This model was trained from scratch with 5000 images (no data augmentation) and scored a [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) of 0.988423 (511 out of 735) on over 100k test images. This score could be improved with more training, data augmentation, fine tuning, playing with CRF post-processing, and applying more weights on the edges of the masks. -The Carvana data is available on the [Kaggle website](https://www.kaggle.com/c/carvana-image-masking-challenge/data). + ## Usage **Note : Use Python 3.6 or newer** + +### Docker + +A docker image containing the code and the dependencies is available on [DockerHub](https://hub.docker.com/repository/docker/milesial/unet). +You can jump in the container with ([docker >=19.03](https://docs.docker.com/get-docker/)): + +```shell script +docker run -it --rm --gpus all milesial/unet +``` + + + +### Training + +```shell script +> python train.py -h +usage: train.py [-h] [--epochs E] [--batch-size B] [--learning-rate LR] + [--load LOAD] [--scale SCALE] [--validation VAL] [--amp] + +Train the UNet on images and target masks + +optional arguments: + -h, --help show this help message and exit + --epochs E, -e E Number of epochs + --batch-size B, -b B Batch size + --learning-rate LR, -l LR + Learning rate + --load LOAD, -f LOAD Load model from a .pth file + --scale SCALE, -s SCALE + Downscaling factor of the images + --validation VAL, -v VAL + Percent of the data that is used as validation (0-100) + --amp Use mixed precision +``` + +By default, the `scale` is 0.5, so if you wish to obtain better results (but use more memory), set it to 1. +The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively. For Carvana, images are RGB and masks are black and white. + ### Prediction -After training your model and saving it to MODEL.pth, you can easily test the output masks on your images via the CLI. +After training your model and saving it to `MODEL.pth`, you can easily test the output masks on your images via the CLI. To predict a single image and save it: @@ -28,7 +64,7 @@ To predict a multiple images and show them without saving them: ```shell script > python predict.py -h -usage: predict.py [-h] [--model FILE] --input INPUT [INPUT ...] +usage: predict.py [-h] [--model FILE] --input INPUT [INPUT ...] [--output INPUT [INPUT ...]] [--viz] [--no-save] [--mask-threshold MASK_THRESHOLD] [--scale SCALE] @@ -38,48 +74,26 @@ optional arguments: -h, --help show this help message and exit --model FILE, -m FILE Specify the file in which the model is stored - (default: MODEL.pth) --input INPUT [INPUT ...], -i INPUT [INPUT ...] - filenames of input images (default: None) + Filenames of input images --output INPUT [INPUT ...], -o INPUT [INPUT ...] - Filenames of ouput images (default: None) - --viz, -v Visualize the images as they are processed (default: - False) - --no-save, -n Do not save the output masks (default: False) + Filenames of output images + --viz, -v Visualize the images as they are processed + --no-save, -n Do not save the output masks --mask-threshold MASK_THRESHOLD, -t MASK_THRESHOLD - Minimum probability value to consider a mask pixel - white (default: 0.5) + Minimum probability value to consider a mask pixel white --scale SCALE, -s SCALE - Scale factor for the input images (default: 0.5) + Scale factor for the input images ``` You can specify which model file to use with `--model MODEL.pth`. -### Training +### Weights & Biases -```shell script -> python train.py -h -usage: train.py [-h] [-e E] [-b [B]] [-l [LR]] [-f LOAD] [-s SCALE] [-v VAL] +The training progress can be visualized in real-time using [Weights & Biases](wandb.ai/). Loss curves, validation curves, weights and gradient histograms, as well as predicted masks are logged to the platform. -Train the UNet on images and target masks +When launching a training, a link will be printed in the console. Click on it to go to your dashboard. If you have an existing W&B account, you can link it + by setting the `WANDB_API_KEY` environment variable. -optional arguments: - -h, --help show this help message and exit - -e E, --epochs E Number of epochs (default: 5) - -b [B], --batch-size [B] - Batch size (default: 1) - -l [LR], --learning-rate [LR] - Learning rate (default: 0.1) - -f LOAD, --load LOAD Load model from a .pth file (default: False) - -s SCALE, --scale SCALE - Downscaling factor of the images (default: 0.5) - -v VAL, --validation VAL - Percent of the data that is used as validation (0-100) - (default: 15.0) - -``` -By default, the `scale` is 0.5, so if you wish to obtain better results (but use more memory), set it to 1. - -The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively. ### Pretrained model A [pretrained model](https://github.com/milesial/Pytorch-UNet/releases/tag/v1.0) is available for the Carvana dataset. It can also be loaded from torch.hub: @@ -89,12 +103,14 @@ net = torch.hub.load('milesial/Pytorch-UNet', 'unet_carvana') ``` The training was done with a 100% scale and bilinear upsampling. -## Tensorboard -You can visualize in real time the train and test losses, the weights and gradients, along with the model predictions with tensorboard: +## Data +The Carvana data is available on the [Kaggle website](https://www.kaggle.com/c/carvana-image-masking-challenge/data). -`tensorboard --logdir=runs` +You can also download it using your Kaggle API key with: -You can find a reference training run with the Caravana dataset on [TensorBoard.dev](https://tensorboard.dev/experiment/1m1Ql50MSJixCbG1m9EcDQ/#scalars&_smoothingWeight=0.6) (only scalars are shown currently). +```shell script +bash download_data.sh +``` ## Notes on memory @@ -103,9 +119,11 @@ Predicting images of 1918*1280 takes 1.5GB of memory. Training takes much approximately 3GB, so if you are a few MB shy of memory, consider turning off all graphical displays. This assumes you use bilinear up-sampling, and not transposed convolution in the model. -## Support +## Convergence + +See a reference training run with the Caravana dataset on [TensorBoard.dev](https://tensorboard.dev/experiment/1m1Ql50MSJixCbG1m9EcDQ/#scalars&_smoothingWeight=0.6) (only scalars are shown currently). + -Personalized support for issues with this repository, or integrating with your own dataset, available on [xs:code](https://xscode.com/milesial/Pytorch-UNet). --- diff --git a/predict.py b/predict.py index 8e158db..348b27c 100755 --- a/predict.py +++ b/predict.py @@ -74,7 +74,7 @@ def mask_to_image(mask: np.ndarray): if mask.ndim == 2: return Image.fromarray((mask * 255).astype(np.uint8)) elif mask.ndim == 3: - return Image.fromarray((np.argmax(mask, dim=0) * 255 / mask.shape[0]).astype(np.uint8)) + return Image.fromarray((np.argmax(mask, axis=0) * 255 / mask.shape[0]).astype(np.uint8)) if __name__ == '__main__':