Cleaned code, added image to README.md

Former-commit-id: 3acf1ff8dadb74e95786fb6ddcf1a90de63f5079
milesial 2017-11-30 07:44:34 +01:00
parent 8b614c3e31
commit 2d427db832
10 changed files with 16 additions and 12 deletions

@@ -1,5 +1,8 @@
# Pytorch-UNet
Customized implementation of the [U-Net](https://arxiv.org/pdf/1505.04597.pdf) in Pytorch for Kaggle's [Carvana Image Masking Challenge](https://www.kaggle.com/c/carvana-image-masking-challenge).
![input and output for a random image in the test dataset](https://framapic.org/YqBT4lbLrcfc/kQcSxYDv1Pfk.png)
Customized implementation of the [U-Net](https://arxiv.org/pdf/1505.04597.pdf) in PyTorch for Kaggle's [Carvana Image Masking Challenge](https://www.kaggle.com/c/carvana-image-masking-challenge), with a single output class, operating on high-definition input images.
This model scored a [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) of 0.988423 (511 out of 735), a modest result that could be improved with more training, data augmentation, fine-tuning, and experimenting with the CRF post-processing.
@@ -7,7 +10,7 @@ The model used for the last submission is stored in the `MODEL.pth` file, if you
## Usage
### Prediction
You can easily test the output masks on your images via the CLI.
To see all options:
@@ -27,3 +30,6 @@ You can specify which model file to use with `--model MODEL.pth`.
## Warning
To process an image, it is split into two square crops (a left one and a right one), and each square is passed through the net. The two square masks are then merged to produce the final mask. As a consequence, the height of the image must be strictly greater than half its width. Make sure the width is even, too.
## Dependencies
This package depends on [pydensecrf](https://github.com/lucasb-eyer/pydensecrf), available via `pip install`.
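The split-and-merge scheme described in the warning above can be sketched as follows; the function names and the averaging of the overlapping band are illustrative, not the project's actual code:

```python
import numpy as np

def split_into_squares(img):
    """Split an H x W image into a left and a right H x H square.

    Assumes H > W / 2, so the two squares together cover the full width.
    """
    h, w = img.shape[:2]
    left = img[:, :h]
    right = img[:, w - h:]
    return left, right

def merge_masks(mask_left, mask_right, full_width):
    """Merge the two square masks back into one H x W mask,
    averaging predictions where the squares overlap in the middle."""
    h = mask_left.shape[0]
    merged = np.zeros((h, full_width), dtype=np.float32)
    counts = np.zeros((h, full_width), dtype=np.float32)
    merged[:, :h] += mask_left
    counts[:, :h] += 1
    merged[:, full_width - h:] += mask_right
    counts[:, full_width - h:] += 1
    return merged / counts
```

The height constraint follows from this scheme: each square has side length equal to the image height, so the two squares can only cover the middle of the image when the height exceeds half the width.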

@@ -2,10 +2,10 @@ import torch
from myloss import dice_coeff
import numpy as np
from torch.autograd import Variable
from data_vis import plot_img_mask
import matplotlib.pyplot as plt
import torch.nn.functional as F
from crf import dense_crf
from utils import dense_crf, plot_img_mask
def eval_net(net, dataset, gpu=False):
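`eval_net` imports `dice_coeff` from `myloss`, which presumably computes the Sørensen-Dice overlap the README reports. A minimal PyTorch sketch of that metric (an illustration, not the project's `myloss` implementation) is:

```python
import torch

def dice_coefficient(pred, target, eps=1e-6):
    """Sorensen-Dice coefficient for binary masks: 2*|A n B| / (|A| + |B|).

    pred and target are float tensors of the same shape with values in [0, 1];
    eps avoids division by zero when both masks are empty.
    """
    pred = pred.contiguous().view(-1)
    target = target.contiguous().view(-1)
    intersection = (pred * target).sum()
    return (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)
```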

@@ -14,7 +14,6 @@ import PIL
import os
#data visualization
from data_vis import plot_img_mask
from utils import *
import matplotlib.pyplot as plt

@@ -8,7 +8,6 @@ import argparse
import os
from utils import *
from crf import dense_crf
from unet import UNet

@@ -3,9 +3,7 @@ import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.nn as nn
from load import *
from data_vis import *
from utils import split_train_val, batch
from utils import *
from myloss import DiceLoss
from eval import eval_net
from unet import UNet
@@ -14,7 +12,6 @@ from torch import optim
from optparse import OptionParser
import sys
import os
import argparse
def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,

utils/__init__.py (new file)
@@ -0,0 +1,4 @@
from .crf import *
from .load import *
from .utils import *
from .data_vis import *
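This new `utils/__init__.py` turns the helper modules into a single package and re-exports their names, which is what enables the import simplifications in the hunks above, for example:

```python
# Before: each helper lived in its own top-level module
# from data_vis import plot_img_mask
# from crf import dense_crf

# After: both names are re-exported through the utils package
from utils import dense_crf, plot_img_mask
```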

@@ -15,7 +15,6 @@ def dense_crf(img, output_probs):
U = np.ascontiguousarray(U)
img = np.ascontiguousarray(img)
d.setUnaryEnergy(U)
d.addPairwiseGaussian(sxy=20, compat=3)
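For context on the CRF post-processing mentioned in the README: as the hunk shows, `dense_crf` sets the unary energy from the network's output probabilities and adds a Gaussian pairwise (smoothness) term. A self-contained sketch of such a refinement step using the pydensecrf API (the function name, the bilateral term, and its parameter values are illustrative assumptions, not necessarily what this file does) could look like:

```python
import numpy as np
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax

def crf_refine(img, prob, n_iters=5):
    """Refine a foreground probability map prob (H x W, float in [0, 1])
    for an RGB image img (H x W x 3, uint8) with a dense CRF."""
    h, w = prob.shape
    # Stack background/foreground probabilities into a (2, H, W) softmax.
    softmax = np.stack([1.0 - prob, prob]).astype(np.float32)
    unary = np.ascontiguousarray(unary_from_softmax(softmax))

    d = dcrf.DenseCRF2D(w, h, 2)
    d.setUnaryEnergy(unary)
    # Location-only smoothness term (same call as in the hunk above).
    d.addPairwiseGaussian(sxy=20, compat=3)
    # Appearance term tying labels to image colours (illustrative values).
    d.addPairwiseBilateral(sxy=30, srgb=13, rgbim=np.ascontiguousarray(img), compat=10)

    q = d.inference(n_iters)
    return np.argmax(q, axis=0).reshape((h, w))
```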

@@ -8,7 +8,7 @@ import numpy as np
from PIL import Image
from functools import partial
from utils import resize_and_crop, get_square, normalize
from .utils import resize_and_crop, get_square, normalize
def get_ids(dir):