Cleaned code, added image to README.md

Former-commit-id: 3acf1ff8dadb74e95786fb6ddcf1a90de63f5079
This commit is contained in:
milesial 2017-11-30 07:44:34 +01:00
parent 8b614c3e31
commit 2d427db832
10 changed files with 16 additions and 12 deletions

View file

@ -1,5 +1,8 @@
# Pytorch-UNet # Pytorch-UNet
Customized implementation of the [U-Net](https://arxiv.org/pdf/1505.04597.pdf) in Pytorch for Kaggle's [Carvana Image Masking Challenge](https://www.kaggle.com/c/carvana-image-masking-challenge). ![input and output for a random image in the test dataset](https://framapic.org/YqBT4lbLrcfc/kQcSxYDv1Pfk.png)
Customized implementation of the [U-Net](https://arxiv.org/pdf/1505.04597.pdf) in Pytorch for Kaggle's [Carvana Image Masking Challenge](https://www.kaggle.com/c/carvana-image-masking-challenge), with only 1 output class, from a high definition image.
This model scored a [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) of 0.988423 (511 out of 735), which is bad but could be improved with more training, data augmentation, fine tuning, and playing with CRF post-processing. This model scored a [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) of 0.988423 (511 out of 735), which is bad but could be improved with more training, data augmentation, fine tuning, and playing with CRF post-processing.
@ -7,7 +10,7 @@ The model used for the last submission is stored in the `MODEL.pth` file, if you
## Usage ## Usage
### Prediction ### Prediction
You can easily test the output masks on your images via the CLI. You can easily test the output masks on your images via the CLI.
To see all options: To see all options:
@ -27,3 +30,6 @@ You can specify which model file to use with `--model MODEL.pth`.
## Warning ## Warning
In order to process the image, it is split into two squares (a left one and a right one), and each square is passed into the net. The two square masks are then merged again to produce the final image. As a consequence, the height of the image must be strictly greater than half the width. Make sure the width is even too. In order to process the image, it is split into two squares (a left one and a right one), and each square is passed into the net. The two square masks are then merged again to produce the final image. As a consequence, the height of the image must be strictly greater than half the width. Make sure the width is even too.
## Dependencies
This package depends on [pydensecrf](https://github.com/lucasb-eyer/pydensecrf), available via `pip install`.

View file

@ -2,10 +2,10 @@ import torch
from myloss import dice_coeff from myloss import dice_coeff
import numpy as np import numpy as np
from torch.autograd import Variable from torch.autograd import Variable
from data_vis import plot_img_mask
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import torch.nn.functional as F import torch.nn.functional as F
from crf import dense_crf
from utils import dense_crf, plot_img_mask
def eval_net(net, dataset, gpu=False): def eval_net(net, dataset, gpu=False):

View file

@ -14,7 +14,6 @@ import PIL
import os import os
#data visualization #data visualization
from data_vis import plot_img_mask
from utils import * from utils import *
import matplotlib.pyplot as plt import matplotlib.pyplot as plt

View file

@ -8,7 +8,6 @@ import argparse
import os import os
from utils import * from utils import *
from crf import dense_crf
from unet import UNet from unet import UNet

View file

@ -3,9 +3,7 @@ import torch.backends.cudnn as cudnn
import torch.nn.functional as F import torch.nn.functional as F
import torch.nn as nn import torch.nn as nn
from load import * from utils import *
from data_vis import *
from utils import split_train_val, batch
from myloss import DiceLoss from myloss import DiceLoss
from eval import eval_net from eval import eval_net
from unet import UNet from unet import UNet
@ -14,7 +12,6 @@ from torch import optim
from optparse import OptionParser from optparse import OptionParser
import sys import sys
import os import os
import argparse
def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05, def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,

4
utils/__init__.py Normal file
View file

@ -0,0 +1,4 @@
from .crf import *
from .load import *
from .utils import *
from .data_vis import *

View file

@ -15,7 +15,6 @@ def dense_crf(img, output_probs):
U = np.ascontiguousarray(U) U = np.ascontiguousarray(U)
img = np.ascontiguousarray(img) img = np.ascontiguousarray(img)
d.setUnaryEnergy(U) d.setUnaryEnergy(U)
d.addPairwiseGaussian(sxy=20, compat=3) d.addPairwiseGaussian(sxy=20, compat=3)

View file

@ -8,7 +8,7 @@ import numpy as np
from PIL import Image from PIL import Image
from functools import partial from functools import partial
from utils import resize_and_crop, get_square, normalize from .utils import resize_and_crop, get_square, normalize
def get_ids(dir): def get_ids(dir):