Cleaned code, added image to README.md
Former-commit-id: 3acf1ff8dadb74e95786fb6ddcf1a90de63f5079
parent 8b614c3e31
commit 2d427db832
README.md (10 changed lines)
@@ -1,5 +1,8 @@
 # Pytorch-UNet
-Customized implementation of the [U-Net](https://arxiv.org/pdf/1505.04597.pdf) in Pytorch for Kaggle's [Carvana Image Masking Challenge](https://www.kaggle.com/c/carvana-image-masking-challenge).
+![input and output for a random image in the test dataset](https://framapic.org/YqBT4lbLrcfc/kQcSxYDv1Pfk.png)
 
 
+Customized implementation of the [U-Net](https://arxiv.org/pdf/1505.04597.pdf) in Pytorch for Kaggle's [Carvana Image Masking Challenge](https://www.kaggle.com/c/carvana-image-masking-challenge), with only 1 output class, from a high definition image.
+
 This model scored a [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) of 0.988423 (511 out of 735), which is not great but could be improved with more training, data augmentation, fine-tuning, and CRF post-processing.
+
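For context on the score quoted in the hunk above: the Sørensen–Dice coefficient measures the overlap between a predicted mask A and a ground-truth mask B as 2|A ∩ B| / (|A| + |B|), so 1.0 means a perfect match. A minimal PyTorch sketch of such a metric (an illustration only, not the repository's `myloss.dice_coeff`, whose implementation is not part of this diff):

import torch

def dice_coefficient(pred, target, eps=1e-6):
    # pred, target: binary masks of identical shape (values in {0, 1})
    pred = pred.float().reshape(-1)
    target = target.float().reshape(-1)
    intersection = (pred * target).sum()
    # eps guards against division by zero when both masks are empty
    return (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)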
@@ -7,7 +10,7 @@ The model used for the last submission is stored in the `MODEL.pth` file, if you
 
 ## Usage
 
 ### Prediction
 
 You can easily test the output masks on your images via the CLI.
 To see all options:
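As a usage illustration for the Prediction section above: only the `--model MODEL.pth` option and a help listing are confirmed by this diff, so the entry-point script name and any other flags are hypothetical placeholders, not commands taken from the repository.

# hypothetical invocation; the script name is assumed, other flags are elided
python main.py -h                     # list all available options
python main.py --model MODEL.pth ...  # predict with a specific checkpoint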
@@ -27,3 +30,6 @@ You can specify which model file to use with `--model MODEL.pth`.
 
 ## Warning
 In order to process the image, it is split into two squares (a left one and a right one), and each square is passed into the net. The two square masks are then merged to produce the final image. As a consequence, the height of the image must be strictly greater than half the width. Make sure the width is even, too.
+
+## Dependencies
+This package depends on [pydensecrf](https://github.com/lucasb-eyer/pydensecrf), available via `pip install`.
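The Warning section above describes the left/right square strategy. A rough sketch of that idea follows (a simplified illustration under the stated constraints; the repository's own helpers, e.g. `get_square`, are referenced elsewhere in this diff but their code is not shown, so the names and details here are assumptions):

import numpy as np

def split_into_squares(img):
    # img: H x W (x C) array with W even and H strictly greater than W / 2
    h, w = img.shape[:2]
    left = img[:, :h]        # left H x H square
    right = img[:, w - h:]   # right H x H square
    return left, right

def merge_masks(left_mask, right_mask, full_width):
    # Recombine the two H x H square masks into an H x full_width mask.
    # Because H > full_width / 2 the squares overlap in the middle; the
    # overlap is resolved by averaging, which is one possible choice.
    h = left_mask.shape[0]
    full = np.zeros((h, full_width), dtype=np.float64)
    counts = np.zeros((h, full_width), dtype=np.float64)
    full[:, :h] += left_mask
    counts[:, :h] += 1.0
    full[:, full_width - h:] += right_mask
    counts[:, full_width - h:] += 1.0
    return full / counts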
eval.py (4 changed lines)
@@ -2,10 +2,10 @@ import torch
 from myloss import dice_coeff
 import numpy as np
 from torch.autograd import Variable
-from data_vis import plot_img_mask
 import matplotlib.pyplot as plt
 import torch.nn.functional as F
-from crf import dense_crf
+from utils import dense_crf, plot_img_mask
 
 
 def eval_net(net, dataset, gpu=False):
main.py (1 changed line)
@@ -14,7 +14,6 @@ import PIL
 import os
 
 #data visualization
-from data_vis import plot_img_mask
 from utils import *
 import matplotlib.pyplot as plt
 
@@ -8,7 +8,6 @@ import argparse
 import os
 
 from utils import *
-from crf import dense_crf
 
 from unet import UNet
 
train.py (5 changed lines)
@@ -3,9 +3,7 @@ import torch.backends.cudnn as cudnn
 import torch.nn.functional as F
 import torch.nn as nn
 
-from load import *
-from data_vis import *
-from utils import split_train_val, batch
+from utils import *
 from myloss import DiceLoss
 from eval import eval_net
 from unet import UNet
@@ -14,7 +12,6 @@ from torch import optim
 from optparse import OptionParser
 import sys
 import os
-import argparse
 
 
 def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
utils/__init__.py (new file, 4 added lines)
@@ -0,0 +1,4 @@
+from .crf import *
+from .load import *
+from .utils import *
+from .data_vis import *
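The new utils/__init__.py turns the former top-level modules into a single package and re-exports their contents, which is what lets the import changes elsewhere in this commit collapse several imports into one. Schematically (names taken from the hunks above):

# before this commit: flat modules at the repository root
from crf import dense_crf
from data_vis import plot_img_mask

# after this commit: one package re-exporting its submodules
from utils import dense_crf, plot_img_mask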
@@ -15,7 +15,6 @@ def dense_crf(img, output_probs):
     U = np.ascontiguousarray(U)
     img = np.ascontiguousarray(img)
 
-
     d.setUnaryEnergy(U)
 
     d.addPairwiseGaussian(sxy=20, compat=3)
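The hunk above touches dense_crf, which builds on the pydensecrf dependency added to the README. For orientation, here is a self-contained sketch of how such CRF post-processing is typically wired up with that library; it is not the repository's exact implementation (only a few lines of it are visible here), and the bilateral term and its parameters in particular are assumptions:

import numpy as np
import pydensecrf.densecrf as dcrf

def dense_crf_sketch(img, output_probs, iters=5):
    # img: H x W x 3 uint8 RGB image; output_probs: H x W foreground probabilities
    h, w = output_probs.shape
    probs = np.clip(np.stack([1.0 - output_probs, output_probs]), 1e-8, 1.0)  # 2 x H x W

    d = dcrf.DenseCRF2D(w, h, 2)
    U = -np.log(probs.reshape(2, -1)).astype(np.float32)  # unary = negative log-probability
    d.setUnaryEnergy(np.ascontiguousarray(U))

    # smoothness kernel, matching the sxy=20, compat=3 values in the hunk above
    d.addPairwiseGaussian(sxy=20, compat=3)
    # appearance kernel driven by the RGB image (illustrative parameters)
    d.addPairwiseBilateral(sxy=30, srgb=13, rgbim=np.ascontiguousarray(img), compat=10)

    Q = d.inference(iters)
    return np.argmax(Q, axis=0).reshape(h, w)  # refined label map (0 = background, 1 = foreground)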
@@ -8,7 +8,7 @@ import numpy as np
 
 from PIL import Image
 from functools import partial
-from utils import resize_and_crop, get_square, normalize
+from .utils import resize_and_crop, get_square, normalize
 
 
 def get_ids(dir):