From 2d427db832e755ef58dfa4dffd984a3534cafdfb Mon Sep 17 00:00:00 2001
From: milesial
Date: Thu, 30 Nov 2017 07:44:34 +0100
Subject: [PATCH] Cleaned code, added image to README.md

Former-commit-id: 3acf1ff8dadb74e95786fb6ddcf1a90de63f5079
---
 README.md                        | 10 ++++++++--
 eval.py                          |  4 ++--
 main.py                          |  1 -
 predict.py                       |  1 -
 train.py                         |  5 +----
 utils/__init__.py                |  4 ++++
 crf.py => utils/crf.py           |  1 -
 data_vis.py => utils/data_vis.py |  0
 load.py => utils/load.py         |  2 +-
 utils.py => utils/utils.py       |  0
 10 files changed, 16 insertions(+), 12 deletions(-)
 create mode 100644 utils/__init__.py
 rename crf.py => utils/crf.py (99%)
 rename data_vis.py => utils/data_vis.py (100%)
 rename load.py => utils/load.py (95%)
 rename utils.py => utils/utils.py (100%)

diff --git a/README.md b/README.md
index e25e76c..131cb62 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,8 @@
 # Pytorch-UNet
-Customized implementation of the [U-Net](https://arxiv.org/pdf/1505.04597.pdf) in Pytorch for Kaggle's [Carvana Image Masking Challenge](https://www.kaggle.com/c/carvana-image-masking-challenge).
+![input and output for a random image in the test dataset](https://framapic.org/YqBT4lbLrcfc/kQcSxYDv1Pfk.png)
+
+
+Customized implementation of the [U-Net](https://arxiv.org/pdf/1505.04597.pdf) in Pytorch for Kaggle's [Carvana Image Masking Challenge](https://www.kaggle.com/c/carvana-image-masking-challenge), with only 1 output class, from a high definition image.
 
 This model scored a [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) of 0.988423 (511 out of 735), which is bad but could be improved with more training, data augmentation, fine tuning, and playing with CRF post-processing.
 
@@ -7,7 +10,7 @@ The model used for the last submission is stored in the `MODEL.pth` file, if you
 
 ## Usage
 
-### Prediction 
+### Prediction
 
 You can easily test the output masks on your images via the CLI.
 To see all options:
@@ -27,3 +30,6 @@ You can specify which model file to use with `--model MODEL.pth`.
 
 ## Warning
 In order to process the image, it is splitted into two squares (a left on and a right one), and each square is passed into the net. The two square masks are then merged again to produce the final image. As a consequence, the height of the image must be strictly superior than half the width. Make sure the width is even too.
+
+## Dependencies
+This package depends on [pydensecrf](https://github.com/lucasb-eyer/pydensecrf), available via `pip install`. 
diff --git a/eval.py b/eval.py
index ee6d666..fb2681d 100644
--- a/eval.py
+++ b/eval.py
@@ -2,10 +2,10 @@ import torch
 from myloss import dice_coeff
 import numpy as np
 from torch.autograd import Variable
-from data_vis import plot_img_mask
 import matplotlib.pyplot as plt
 import torch.nn.functional as F
-from crf import dense_crf
+
+from utils import dense_crf, plot_img_mask
 
 
 def eval_net(net, dataset, gpu=False):
diff --git a/main.py b/main.py
index 3e27096..06ede62 100644
--- a/main.py
+++ b/main.py
@@ -14,7 +14,6 @@ import PIL
 import os
 
 #data visualization
-from data_vis import plot_img_mask
 from utils import *
 import matplotlib.pyplot as plt
 
diff --git a/predict.py b/predict.py
index 9a2ded1..e173a7d 100644
--- a/predict.py
+++ b/predict.py
@@ -8,7 +8,6 @@ import argparse
 import os
 
 from utils import *
-from crf import dense_crf
 from unet import UNet
 
 
diff --git a/train.py b/train.py
index 5b3f325..0ef7986 100644
--- a/train.py
+++ b/train.py
@@ -3,9 +3,7 @@ import torch.backends.cudnn as cudnn
 import torch.nn.functional as F
 import torch.nn as nn
 
-from load import *
-from data_vis import *
-from utils import split_train_val, batch
+from utils import *
 from myloss import DiceLoss
 from eval import eval_net
 from unet import UNet
@@ -14,7 +12,6 @@ from torch import optim
 from optparse import OptionParser
 import sys
 import os
-import argparse
 
 
 def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..54e2c6f
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,4 @@
+from .crf import *
+from .load import *
+from .utils import *
+from .data_vis import *
diff --git a/crf.py b/utils/crf.py
similarity index 99%
rename from crf.py
rename to utils/crf.py
index 713a47b..5ee718f 100644
--- a/crf.py
+++ b/utils/crf.py
@@ -15,7 +15,6 @@ def dense_crf(img, output_probs):
     U = np.ascontiguousarray(U)
     img = np.ascontiguousarray(img)
 
-
     d.setUnaryEnergy(U)
 
     d.addPairwiseGaussian(sxy=20, compat=3)
diff --git a/data_vis.py b/utils/data_vis.py
similarity index 100%
rename from data_vis.py
rename to utils/data_vis.py
diff --git a/load.py b/utils/load.py
similarity index 95%
rename from load.py
rename to utils/load.py
index 9b847d0..41ff3b4 100644
--- a/load.py
+++ b/utils/load.py
@@ -8,7 +8,7 @@ import numpy as np
 from PIL import Image
 from functools import partial
 
-from utils import resize_and_crop, get_square, normalize
+from .utils import resize_and_crop, get_square, normalize
 
 
 def get_ids(dir):
diff --git a/utils.py b/utils/utils.py
similarity index 100%
rename from utils.py
rename to utils/utils.py
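
Note on the README Warning section touched above (splitting a high-definition image into a left and a right square before inference): the sketch below only illustrates that description. The helper names `split_into_squares` and `merge_masks` and the exact crop/merge indices are assumptions for illustration; the repository's actual helpers live in `utils/utils.py` and are not part of this patch.

```python
import numpy as np

def split_into_squares(img):
    # Illustrative only: cut an H x W image into a left and a right H x H square,
    # as described in the README Warning. Requires H strictly greater than W / 2
    # (so the two squares cover the full width) and an even W.
    h, w = img.shape[:2]
    assert h > w / 2, "height must be strictly greater than half the width"
    assert w % 2 == 0, "width must be even"
    return img[:, :h], img[:, w - h:]

def merge_masks(left_mask, right_mask, full_w):
    # Keep the left half of the left mask and the right half of the right mask;
    # the overlapping strip in the middle is simply cut at full_w // 2.
    h = left_mask.shape[0]
    half = full_w // 2
    merged = np.zeros((h, full_w), dtype=left_mask.dtype)
    merged[:, :half] = left_mask[:, :half]
    merged[:, half:] = right_mask[:, half - (full_w - h):]
    return merged
```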
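
The patch also routes `dense_crf` through the new `utils` package, and the README now lists pydensecrf as a dependency. For orientation only, a binary dense-CRF refinement step with pydensecrf looks roughly like the sketch below; apart from the `sxy=20, compat=3` Gaussian kernel visible in the `utils/crf.py` hunk, the parameter values, the probability clipping, and the `crf_refine` name are assumptions, not the repository's exact code.

```python
import numpy as np
import pydensecrf.densecrf as dcrf

def crf_refine(img, prob, iters=5):
    # Sketch of binary dense-CRF post-processing in the spirit of utils/crf.py.
    # img: H x W x 3 uint8 RGB array; prob: H x W foreground probability map.
    h, w = prob.shape
    probs = np.stack([1 - prob, prob], axis=0)          # (2, H, W) class probabilities
    unary = -np.log(np.clip(probs, 1e-8, 1.0))          # negative log-probabilities
    unary = np.ascontiguousarray(unary.reshape(2, -1).astype(np.float32))

    d = dcrf.DenseCRF2D(w, h, 2)                        # width, height, 2 labels
    d.setUnaryEnergy(unary)
    d.addPairwiseGaussian(sxy=20, compat=3)             # location-only smoothness
    d.addPairwiseBilateral(sxy=30, srgb=20, compat=10,
                           rgbim=np.ascontiguousarray(img))  # colour-aware smoothness
    Q = d.inference(iters)
    return np.argmax(Q, axis=0).reshape(h, w)           # refined 0/1 mask
```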