diff --git a/README.md b/README.md index 2fa6e10..e25e76c 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ This model scored a [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rens The model used for the last submission is stored in the `MODEL.pth` file, if you wish to play with it. The data is available on the [Kaggle website](https://www.kaggle.com/c/carvana-image-masking-challenge/data). ## Usage + ### Prediction You can easily test the output masks on your images via the CLI. @@ -13,13 +14,16 @@ To see all options: `python predict.py -h` To predict a single image and save it: -`python predict.py -i image.jpg -o ouput.jpg + +`python predict.py -i image.jpg -o ouput.jpg` To predict a multiple images and show them without saving them: + `python predict.py -i image1.jpg image2.jpg --viz --no-save` You can use the cpu-only version with `--cpu`. + You can specify which model file to use with `--model MODEL.pth`. -## Note -The code and the overall project architecture is a big mess for now, as I left it abandoned when the challenge finished. I will clean it SoonTM. +## Warning +In order to process the image, it is splitted into two squares (a left on and a right one), and each square is passed into the net. The two square masks are then merged again to produce the final image. As a consequence, the height of the image must be strictly superior than half the width. Make sure the width is even too. diff --git a/predict.py b/predict.py index 7a4c792..9a2ded1 100644 --- a/predict.py +++ b/predict.py @@ -12,6 +12,7 @@ from crf import dense_crf from unet import UNet + def predict_img(net, full_img, gpu=False): img = resize_and_crop(full_img) @@ -39,7 +40,7 @@ def predict_img(net, full_img, gpu=False): y_l = F.upsample_bilinear(y_l, scale_factor=2).data[0][0].cpu().numpy() y_r = F.upsample_bilinear(y_r, scale_factor=2).data[0][0].cpu().numpy() - y = merge_masks(y_l, y_r, 1918) + y = merge_masks(y_l, y_r, full_img.size[0]) yy = dense_crf(np.array(full_img).astype(np.uint8), y) return yy > 0.5 diff --git a/submit.py b/submit.py index c82e55d..c549ce2 100644 --- a/submit.py +++ b/submit.py @@ -1,3 +1,4 @@ +# used to predict all test images and encode results in a csv file import os from PIL import Image